
Installing and Using Triton

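Triton is a dynamic binary analysis framework (a library with C++ and Python APIs), used mainly to analyze and inspect binaries such as executables and shared libraries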

Its main features include:

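  1. Disassembly and analysis:
    • Triton disassembles the instructions it is given (through Capstone) and exposes their operands and semantics for analysis
    • It supports several instruction set architectures, such as x86, x86-64, ARM32 and AArch64
  2. Automated analysis:
    • Triton provides a rich API that lets developers write automated binary analysis scripts
    • Such scripts can be used to look for vulnerabilities, malicious code and other security issues in a binary
  3. Extensibility:
    • Triton is a library rather than a finished tool, so it is extended by building your own analysis tools on top of its API
    • Custom behaviour can also be attached through callbacks, for example to supply memory contents on demand
  4. Fine-grained control over execution:
    • Triton has no built-in interactive debugger; instead, your script drives execution one instruction at a time
    • Breakpoints, single-stepping and state inspection are therefore easy to implement in the emulation loop, as the scripts below do
  5. Cross-platform support:
    • Triton runs on Windows, Linux and macOS, so the same analysis can be done on any of these systems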

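Execution in Triton is entirely under our control: we decide which instruction is processed next. Both taint analysis and symbolic execution are built on top of this emulation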

Triton's dependencies:

* libcapstone            >= 4.0.x   https://github.com/capstone-engine/capstone
* libboost    (optional) >= 1.68
* libpython   (optional) >= 3.6
* libz3       (optional) >= 4.6.0   https://github.com/Z3Prover/z3
* libbitwuzla (optional) >= 0.4.x   https://github.com/bitwuzla/bitwuzla
* llvm        (optional) >= 12
  • If the resulting /usr/lib/python3.8/site-packages/triton.so segfaults, a libcapstone version mismatch is the most likely cause

Installing Triton and its dependencies with vcpkg:

git clone https://github.com/Microsoft/vcpkg.git
cd vcpkg
./bootstrap-vcpkg.sh # ./bootstrap-vcpkg.bat for Windows
./vcpkg integrate install
./vcpkg install triton
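  • vcpkg is a command-line C/C++ package manager developed by Microsoft
  • It aims to simplify acquiring and installing C/C++ libraries across platforms, Windows in particular
  • vcpkg works with many compilers and build systems and integrates with Visual Studio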
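./vcpkg list              # list installed libraries
./vcpkg search <keyword>  # search for available libraries
./vcpkg update            # list installed libraries that have newer versions (vcpkg itself is updated with git pull)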

To build Triton from source instead, install z3 first (it is also available in vcpkg):

git clone https://github.com/Z3Prover/z3.git
cd z3
python scripts/mk_make.py --prefix=/home/yhellow --python --pypkgdir=/home/yhellow/.local/lib/python3.8/site-packages/
cd build
make -j8
sudo make install

Then build and install Triton:

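git clone https://github.com/JonathanSalwan/Triton.git
cd Triton
mkdir build
cd build
cmake ..
make -j8
sudo make install
sudo mv /usr/lib/python3.8/site-packages/triton.so /home/yhellow/.local/lib/python3.8/site-packages/  # move the module into the user site-packages used above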

Emulation with Triton

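Emulation is a technique that lets software or hardware imitate the behaviour of another system, so that programs can run, or tasks can be performed, in a different environment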
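To make the idea concrete, here is a minimal sketch of an emulation loop for a hypothetical toy instruction set. It is not x86 and does not use Triton's API. The loop fetches the instruction at pc, applies its effect to an emulated register file, then computes the next pc. The Triton scripts in this article follow the same fetch, execute and next-pc structure.

```python
# Toy emulator for a made-up mini ISA, only to illustrate the
# fetch -> execute -> update-pc loop that any emulator runs.

def emulate(program, entry):
    """Run `program` (dict: address -> instruction tuple) starting at `entry`.

    Returns the final register file and the list of executed addresses.
    """
    regs = {}
    trace = []
    pc = entry
    while pc is not None:
        op, *args = program[pc]        # fetch and decode
        trace.append(pc)
        next_pc = pc + 1               # default: fall through
        if op == "mov":                # mov reg, imm
            regs[args[0]] = args[1]
        elif op == "add":              # add dst, src  (dst += src)
            regs[args[0]] += regs[args[1]]
        elif op == "dec":              # dec reg
            regs[args[0]] -= 1
        elif op == "jnz":              # jnz reg, target
            if regs[args[0]] != 0:
                next_pc = args[1]
        elif op == "halt":
            next_pc = None
        else:
            raise ValueError(f"unknown instruction {op!r} at {pc:#x}")
        pc = next_pc                   # the emulator decides where execution goes next
    return regs, trace


# Sum 3 + 2 + 1 into rax with a countdown loop.
program = {
    0: ("mov", "rax", 0),
    1: ("mov", "rcx", 3),
    2: ("add", "rax", "rcx"),
    3: ("dec", "rcx"),
    4: ("jnz", "rcx", 2),
    5: ("halt",),
}
regs, trace = emulate(program, 0)
print(regs["rax"], len(trace))  # prints: 6 12
```

Because the loop owns the program counter, pausing, inspecting or redirecting execution is just ordinary code in the loop body.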

A first example:

from __future__ import print_function
from triton import *

code = [  # each entry is (instruction address, instruction bytes)
    (0x400000, b"\x48\x8b\x05\xb8\x13\x00\x00"),  # mov rax, QWORD PTR [rip+0x13b8]
    (0x400007, b"\x48\x8d\x34\xc3"),              # lea rsi, [rbx+rax*8]
    (0x40000b, b"\x67\x48\x8D\x74\xC3\x0A"),      # lea rsi, [ebx+eax*8+0xa]
    (0x400011, b"\x66\x0F\xD7\xD1"),              # pmovmskb edx, xmm1
    (0x400015, b"\x89\xd0"),                      # mov eax, edx
    (0x400017, b"\x80\xf4\x99"),                  # xor ah, 0x99
    (0x40001a, b"\xC5\xFD\x6F\xCA"),              # vmovdqa ymm1, ymm2
]

if __name__ == '__main__':
    ctx = TritonContext()
    ctx.setArchitecture(ARCH.X86_64)  # architecture of the code to emulate

    for (addr, opcode) in code:
        inst = Instruction()    # create an instruction object
        inst.setOpcode(opcode)  # set its raw bytes
        inst.setAddress(addr)   # set its address
        ctx.processing(inst)    # execute the instruction

        print(inst)             # print the disassembled instruction
        print(' ---------------')
        print(' Is memory read :', inst.isMemoryRead())
        print(' Is memory write:', inst.isMemoryWrite())
        print(' ---------------')
        for op in inst.getOperands():
            print(' Operand:', op)
            if op.getType() == OPERAND.MEM:  # memory operands also expose their addressing components
                print(' - segment :', op.getSegmentRegister())
                print(' - base :', op.getBaseRegister())
                print(' - index :', op.getIndexRegister())
                print(' - scale :', op.getScale())
                print(' - disp :', op.getDisplacement())
            print(' ---------------')
        print()
0x400000: mov rax, qword ptr [rip + 0x13b8]
---------------
Is memory read : True
Is memory write: False
---------------
Operand: rax:64 bv[63..0]
---------------
Operand: [@0x4013bf]:64 bv[63..0]
- segment : unknown:1 bv[0..0]
- base : rip:64 bv[63..0]
- index : unknown:1 bv[0..0]
- scale : 0x1:64 bv[63..0]
- disp : 0x13b8:64 bv[63..0]
---------------

0x400007: lea rsi, [rbx + rax*8]
---------------
Is memory read : False
Is memory write: False
---------------
Operand: rsi:64 bv[63..0]
---------------
Operand: [@0x0]:64 bv[63..0]
- segment : unknown:1 bv[0..0]
- base : rbx:64 bv[63..0]
- index : rax:64 bv[63..0]
- scale : 0x8:64 bv[63..0]
- disp : 0x0:64 bv[63..0]
---------------

0x40000b: lea rsi, [ebx + eax*8 + 0xa]
---------------
Is memory read : False
Is memory write: False
---------------
Operand: rsi:64 bv[63..0]
---------------
Operand: [@0xa]:64 bv[63..0]
- segment : unknown:1 bv[0..0]
- base : ebx:32 bv[31..0]
- index : eax:32 bv[31..0]
- scale : 0x8:32 bv[31..0]
- disp : 0xa:32 bv[31..0]
---------------

0x400011: pmovmskb edx, xmm1
---------------
Is memory read : False
Is memory write: False
---------------
Operand: edx:32 bv[31..0]
---------------
Operand: xmm1:128 bv[127..0]
---------------

0x400015: mov eax, edx
---------------
Is memory read : False
Is memory write: False
---------------
Operand: eax:32 bv[31..0]
---------------
Operand: edx:32 bv[31..0]
---------------

0x400017: xor ah, 0x99
---------------
Is memory read : False
Is memory write: False
---------------
Operand: ah:8 bv[15..8]
---------------
Operand: 0x99:8 bv[7..0]
---------------

0x40001a: vmovdqa ymm1, ymm2
---------------
Is memory read : False
Is memory write: False
---------------
Operand: ymm1:256 bv[255..0]
---------------
Operand: ymm2:256 bv[255..0]
---------------

The example above does not yet show what emulation is really for. The next example emulates an actual binary:

gcc test.c -o test -static
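  • Link statically. The emulation loop below has no dynamic linker, so it cannot resolve .got.plt entries. In a static binary the libc functions are linked into the executable and called directly, without going through .got.plt
  • Static linking also disables PIE here, so the addresses shown in IDA can be used as-is, without computing a load offset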
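#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

int main() {
    char buf[0x10];

    read(0, buf, 0x200);   /* note: reads up to 0x200 bytes into a 0x10-byte buffer */
    int fd = open(buf, 0);
    read(fd, buf, 0x10);
    write(1, buf, 0x10);
    return 0;
}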
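  • On disk, the entries in .got.plt may only be placeholders. In memory, the dynamic linker replaces them with the real addresses, which it computes from the relocation entries in .rela.plt
  • IDA resolves these entries automatically, so what IDA displays is not the raw on-disk data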
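The following sketch models what the dynamic linker does with .rela.plt, using made-up addresses and a plain dict as memory. These are hypothetical values, not real ELF parsing. Each relocation names a GOT slot and a symbol, and the loader writes the symbol's resolved address into that slot. A bare emulation loop never performs this step, so the slots keep their on-disk placeholders.

```python
# Toy model of GOT patching. All addresses here are invented for illustration.

def apply_jump_slot_relocations(got, relocations, symbol_table):
    """Write each symbol's resolved address into its GOT slot.

    got:          dict mapping slot address -> 8-byte value (modified in place)
    relocations:  list of (slot_address, symbol_name), like .rela.plt entries
    symbol_table: dict mapping symbol_name -> address inside the loaded library
    """
    for slot, name in relocations:
        got[slot] = symbol_table[name]
    return got


# On disk, the slots hold placeholders (here: addresses back inside the PLT).
got = {0x4018: 0x1036, 0x4020: 0x1046}
rela_plt = [(0x4018, "read"), (0x4020, "write")]
libc = {"read": 0x7FFFF7E9A000, "write": 0x7FFFF7E9A0A0}

apply_jump_slot_relocations(got, rela_plt, libc)
print(hex(got[0x4018]), hex(got[0x4020]))
```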

The emulation script:

from triton import *

def taint_analysis2(start, end):
    ctx = TritonContext()
    with open('./test', 'rb') as f:
        bin1 = f.read()

    ctx.setArchitecture(ARCH.X86_64)
    # Map the raw file at the image base 0x400000 (taken from IDA). This is a shortcut,
    # not an ELF loader: it only works where file offset + 0x400000 == virtual address
    ctx.setConcreteMemoryAreaValue(0x400000, bin1)
    RBP_ADDR = 0x60000000
    RSP_ADDR = RBP_ADDR - 0x20000000
    INPUT_ADDR = 0x10000000

    ctx.setConcreteRegisterValue(ctx.registers.rip, start)     # set rip
    ctx.setConcreteRegisterValue(ctx.registers.rsp, RSP_ADDR)  # set rsp
    ctx.setConcreteRegisterValue(ctx.registers.rbp, RBP_ADDR)  # set rbp

    user_input = b"yhellow\x00"
    ctx.setConcreteMemoryAreaValue(INPUT_ADDR, user_input)     # place the string at INPUT_ADDR
    pc = start

    # Note: Triton emulates instructions, not the kernel, so the system calls
    # made by read/open/write have no real effect here
    while pc:
        inst = Instruction()
        opcode = ctx.getConcreteMemoryAreaValue(pc, 16)  # fetch 16 bytes (the maximum x86 instruction length)
        inst.setOpcode(opcode)
        inst.setAddress(pc)
        ctx.processing(inst)
        print(str(inst))

        if pc == end:  # stop after the last instruction of the range
            break
        pc = ctx.getConcreteRegisterValue(ctx.registers.rip)  # read rip: address of the next instruction

if __name__ == '__main__':
    taint_analysis2(0x401CD5, 0x401D74)  # address range of main (from IDA)

Output:

0x401cd5: endbr64
0x401cd9: push rbp
0x401cda: mov rbp, rsp
0x401cdd: sub rsp, 0x30
0x401ce1: mov rax, qword ptr fs:[0x28]
0x401cea: mov qword ptr [rbp - 8], rax
0x401cee: xor eax, eax
0x401cf0: lea rax, [rbp - 0x20]
0x401cf4: mov edx, 0x200
0x401cf9: mov rsi, rax
0x401cfc: mov edi, 0
0x401d01: mov eax, 0
0x401d06: call 0x4483f0 /* the call jumps into the callee's code */
0x4483f0: endbr64
0x4483f4: mov eax, dword ptr fs:[0x18]
0x4483fc: test eax, eax
0x4483fe: jne 0x448410
0x448400: syscall
0x448402: cmp rax, -0x1000
0x448408: ja 0x448460
0x44840a: ret /* return to the caller */
0x401d0b: lea rax, [rbp - 0x20]
0x401d0f: mov esi, 0
0x401d14: mov rdi, rax
0x401d17: mov eax, 0
0x401d1c: call 0x4482c0
0x4482c0: endbr64
0x4482c4: push r12
0x4482c6: mov r10d, esi
0x4482c9: mov r12d, esi
0x4482cc: push rbp
0x4482cd: mov rbp, rdi
0x4482d0: sub rsp, 0x68
0x4482d4: mov qword ptr [rsp + 0x40], rdx
0x4482d9: mov rax, qword ptr fs:[0x28]
0x4482e2: mov qword ptr [rsp + 0x28], rax
0x4482e7: xor eax, eax
0x4482e9: and r10d, 0x40
0x4482ed: jne 0x448348
0x4482ef: mov eax, esi
0x4482f1: and eax, 0x410000
0x4482f6: cmp eax, 0x410000
0x4482fb: je 0x448348
0x4482fd: mov eax, dword ptr fs:[0x18]
0x448305: test eax, eax
0x448307: jne 0x448370
0x448309: mov edx, r12d
0x44830c: mov rsi, rbp
0x44830f: mov edi, 0xffffff9c
0x448314: mov eax, 0x101
0x448319: syscall
0x44831b: cmp rax, -0x1000
0x448321: ja 0x4483b8
0x448327: mov rcx, qword ptr [rsp + 0x28]
0x44832c: xor rcx, qword ptr fs:[0x28]
0x448335: jne 0x4483e1
0x44833b: add rsp, 0x68
0x44833f: pop rbp
0x448340: pop r12
0x448342: ret
0x401d21: mov dword ptr [rbp - 0x24], eax
0x401d24: lea rcx, [rbp - 0x20]
0x401d28: mov eax, dword ptr [rbp - 0x24]
0x401d2b: mov edx, 0x10
0x401d30: mov rsi, rcx
0x401d33: mov edi, eax
0x401d35: mov eax, 0
0x401d3a: call 0x4483f0
0x4483f0: endbr64
0x4483f4: mov eax, dword ptr fs:[0x18]
0x4483fc: test eax, eax
0x4483fe: jne 0x448410
0x448400: syscall
0x448402: cmp rax, -0x1000
0x448408: ja 0x448460
0x44840a: ret
0x401d3f: lea rax, [rbp - 0x20]
0x401d43: mov edx, 0x10
0x401d48: mov rsi, rax
0x401d4b: mov edi, 1
0x401d50: mov eax, 0
0x401d55: call 0x448490
0x448490: endbr64
0x448494: mov eax, dword ptr fs:[0x18]
0x44849c: test eax, eax
0x44849e: jne 0x4484b0
0x4484a0: mov eax, 1
0x4484a5: syscall
0x4484a7: cmp rax, -0x1000
0x4484ad: ja 0x448500
0x4484af: ret
0x401d5a: mov eax, 0
0x401d5f: mov rdx, qword ptr [rbp - 8]
0x401d63: sub rdx, qword ptr fs:[0x28]
0x401d6c: je 0x401d73
0x401d73: leave
0x401d74: ret

Taint Analysis with Triton

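Taint analysis is a security analysis technique that tracks how data flows through a program, in particular data that may come from untrusted input. It helps identify and prevent vulnerabilities such as cross-site scripting (XSS), SQL injection and command injection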
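The following is a minimal sketch of the idea over a hypothetical three-address program rather than machine code. Values read from an untrusted source are marked tainted, and every assignment propagates taint from its operands to its destination. A report is produced when tainted data reaches a sensitive sink. This is a toy model for illustration, not how Triton is implemented internally.

```python
# Toy taint tracker. Each statement is one of:
#   ("source", dst)         dst receives untrusted input
#   ("const", dst)          dst receives a constant
#   ("assign", dst, *srcs)  dst is computed from srcs
#   ("sink", name, arg)     arg is passed to the sensitive function `name`

def find_tainted_sinks(program):
    tainted = set()
    findings = []
    for index, stmt in enumerate(program):
        kind = stmt[0]
        if kind == "source":
            tainted.add(stmt[1])
        elif kind == "const":
            tainted.discard(stmt[1])      # overwriting with a constant clears taint
        elif kind == "assign":
            dst, srcs = stmt[1], stmt[2:]
            if any(src in tainted for src in srcs):
                tainted.add(dst)
            else:
                tainted.discard(dst)
        elif kind == "sink":
            name, arg = stmt[1], stmt[2]
            if arg in tainted:
                findings.append((index, name, arg))
    return findings


program = [
    ("source", "user"),                      # user = read_request()
    ("const", "prefix"),                     # prefix = "SELECT * FROM t WHERE id="
    ("assign", "query", "prefix", "user"),   # query = prefix + user
    ("sink", "sql_execute", "query"),        # tainted: reported
    ("const", "query"),                      # query = "SELECT 1"
    ("sink", "sql_execute", "query"),        # clean: not reported
]
print(find_tainted_sinks(program))  # prints: [(3, 'sql_execute', 'query')]
```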

The taint analysis script:

from triton import *

def taint_analysis2(start, end):
    ctx = TritonContext()
    with open('./test', 'rb') as f:
        bin1 = f.read()

    ctx.setArchitecture(ARCH.X86_64)
    ctx.setConcreteMemoryAreaValue(0x400000, bin1)
    RBP_ADDR = 0x7ffffffde000
    RSP_ADDR = RBP_ADDR - 0x21000
    INPUT_ADDR = 0x4C2290  # address of the global buf (from IDA)

    ctx.setConcreteRegisterValue(ctx.registers.rip, start)
    ctx.setConcreteRegisterValue(ctx.registers.rsp, RSP_ADDR)
    ctx.setConcreteRegisterValue(ctx.registers.rbp, RBP_ADDR)

    user_input = b"./flag\x00"
    ctx.setConcreteMemoryAreaValue(INPUT_ADDR, user_input)

    pc = start
    tainted_addrs = []  # addresses of tainted instructions

    while pc:
        inst = Instruction()
        opcode = ctx.getConcreteMemoryAreaValue(pc, 16)

        inst.setOpcode(opcode)
        inst.setAddress(pc)
        ctx.processing(inst)

        print(str(inst))
        if pc == 0x401cf5:  # call open: rdi holds the pathname argument
            print("--------------")
            print("taint target: " + hex(ctx.getConcreteRegisterValue(ctx.registers.rdi)))  # read the register value
            print("--------------")
            # Taint the rdi register itself (the pointer it holds), not the memory it points to
            ctx.taintRegister(ctx.registers.rdi)

        if inst.isTainted():  # does this instruction involve tainted data?
            tainted_addrs.append(hex(pc))
            print("--------------")
            if inst.isMemoryRead():
                for op in inst.getOperands():
                    if op.getType() == OPERAND.MEM:
                        print("read:0x{:08x}, size:{}".format(
                            op.getAddress(), op.getSize()))

            if inst.isMemoryWrite():
                for op in inst.getOperands():
                    if op.getType() == OPERAND.MEM:
                        print("write:0x{:08x}, size:{}".format(
                            op.getAddress(), op.getSize()))

        if pc == end:  # stop after the last instruction of the range
            break
        pc = ctx.getConcreteRegisterValue(ctx.registers.rip)

    print(tainted_addrs)

if __name__ == '__main__':
    taint_analysis2(0x401CD5, 0x401D3F)

Output:

0x401cd5: endbr64
0x401cd9: push rbp
0x401cda: mov rbp, rsp
0x401cdd: sub rsp, 0x10
0x401ce1: mov esi, 0
0x401ce6: lea rax, [rip + 0xc05a3]
0x401ced: mov rdi, rax
0x401cf0: mov eax, 0
0x401cf5: call 0x448280
--------------
taint target: 0x4c2290
--------------
0x448280: endbr64
0x448284: push r12
0x448286: mov r10d, esi
0x448289: mov r12d, esi
0x44828c: push rbp
0x44828d: mov rbp, rdi
--------------
0x448290: sub rsp, 0x68
0x448294: mov qword ptr [rsp + 0x40], rdx
0x448299: mov rax, qword ptr fs:[0x28]
0x4482a2: mov qword ptr [rsp + 0x28], rax
0x4482a7: xor eax, eax
0x4482a9: and r10d, 0x40
0x4482ad: jne 0x448308
0x4482af: mov eax, esi
0x4482b1: and eax, 0x410000
0x4482b6: cmp eax, 0x410000
0x4482bb: je 0x448308
0x4482bd: mov eax, dword ptr fs:[0x18]
0x4482c5: test eax, eax
0x4482c7: jne 0x448330
0x4482c9: mov edx, r12d
0x4482cc: mov rsi, rbp
--------------
0x4482cf: mov edi, 0xffffff9c
0x4482d4: mov eax, 0x101
0x4482d9: syscall
0x4482db: cmp rax, -0x1000
0x4482e1: ja 0x448378
0x4482e7: mov rcx, qword ptr [rsp + 0x28]
0x4482ec: xor rcx, qword ptr fs:[0x28]
0x4482f5: jne 0x4483a1
0x4482fb: add rsp, 0x68
0x4482ff: pop rbp
0x448300: pop r12
0x448302: ret
0x401cfa: mov dword ptr [rbp - 4], eax
0x401cfd: mov eax, dword ptr [rbp - 4]
0x401d00: mov edx, 0x10
0x401d05: lea rcx, [rip + 0xc0584]
0x401d0c: mov rsi, rcx
0x401d0f: mov edi, eax
0x401d11: mov eax, 0
0x401d16: call 0x4483b0
0x4483b0: endbr64
0x4483b4: mov eax, dword ptr fs:[0x18]
0x4483bc: test eax, eax
0x4483be: jne 0x4483d0
0x4483c0: syscall
0x4483c2: cmp rax, -0x1000
0x4483c8: ja 0x448420
0x4483ca: ret
0x401d1b: mov edx, 0x10
0x401d20: lea rax, [rip + 0xc0569]
0x401d27: mov rsi, rax
0x401d2a: mov edi, 1
0x401d2f: mov eax, 0
0x401d34: call 0x448450
0x448450: endbr64
0x448454: mov eax, dword ptr fs:[0x18]
0x44845c: test eax, eax
0x44845e: jne 0x448470
0x448460: mov eax, 1
0x448465: syscall
0x448467: cmp rax, -0x1000
0x44846d: ja 0x4484c0
0x44846f: ret
0x401d39: mov eax, 0
0x401d3e: leave
0x401d3f: ret
['0x44828d', '0x4482cc']
  • The register state at these two addresses can be inspected in GDB:
......
*RDI 0x4c2290 (buf) ◂— 0x0
......
*RBP 0x7fffffffdb40 —▸ 0x402d20 (__libc_csu_init) ◂— endbr64
*RSP 0x7fffffffdb18 —▸ 0x7fffffffdb40 —▸ 0x402d20 (__libc_csu_init) ◂— endbr64
*RIP 0x44828d (open64+13) ◂— mov rbp, rdi
───────────────────────────────────[ DISASM ]───────────────────────────────────
0x44828d <open64+13> mov rbp, rdi <buf>
......
RSI 0x0
......
*RBP 0x4c2290 (buf) ◂— 0x0
*RSP 0x7fffffffdab0 ◂— 0x1b
*RIP 0x4482cc (open64+76) ◂— mov rsi, rbp
───────────────────────────────────[ DISASM ]───────────────────────────────────
0x4482cc <open64+76> mov rsi, rbp <buf>
  • Both tainted instructions handle the value 0x4c2290, but that does not mean every instruction involving 0x4c2290 is tainted
  • For example, consider this state:
......
*RCX 0x4c2290 (buf) ◂— 0x0
......
RSI 0x4c2290 (buf) ◂— 0x0
......
RBP 0x7fffffffdb40 —▸ 0x402d20 (__libc_csu_init) ◂— endbr64
RSP 0x7fffffffdb30 ◂— 0x0
*RIP 0x401d0c (main+55) ◂— mov rsi, rcx
───────────────────────────────────[ DISASM ]───────────────────────────────────
0x401d0c <main+55> mov rsi, rcx
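  • This instruction also handles 0x4c2290, but this copy of the value has a different origin:
    • In the tainted instructions, 0x4c2290 comes from rdi, which was tainted at the call to open and got its value from 0x401ce6: lea rax, [rip + 0xc05a3]
    • Here, 0x4c2290 comes from 0x401d05: lea rcx, [rip + 0xc0584], which computes it afresh from rip and a constant, so it carries no taint
  • In other words, Triton propagates taint along data flow, not by value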
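The difference can be reproduced with a toy model that replays a simplified subset of the trace above. The model is hypothetical and is not Triton's implementation. Taint belongs to registers: it is introduced the way ctx.taintRegister(rdi) introduces it, and it is propagated by mov. A lea from a constant produces an untainted register, even though that register holds the same value 0x4c2290. The operation names `lea_const` and `taint` are made up for this sketch.

```python
# Toy register-level taint tracking over a simplified subset of the trace above.
BUF = 0x4C2290

def replay(trace):
    regs = {}           # register -> concrete value
    tainted = set()     # registers currently carrying taint
    tainted_insts = []  # instructions that read tainted data
    for addr, op, dst, src in trace:
        if op == "lea_const":        # dst = constant address, e.g. lea reg, [rip + off]
            regs[dst] = src
            tainted.discard(dst)
        elif op == "mov":            # dst = src (register to register)
            regs[dst] = regs[src]
            if src in tainted:
                tainted.add(dst)
                tainted_insts.append(addr)
            else:
                tainted.discard(dst)
        elif op == "taint":          # like ctx.taintRegister(...)
            tainted.add(dst)
    return regs, tainted, tainted_insts

trace = [
    (0x401CE6, "lea_const", "rax", BUF),
    (0x401CED, "mov", "rdi", "rax"),
    (0x401CF5, "taint", "rdi", None),   # taintRegister(rdi) at the call to open
    (0x44828D, "mov", "rbp", "rdi"),
    (0x4482CC, "mov", "rsi", "rbp"),
    (0x401D05, "lea_const", "rcx", BUF),
    (0x401D0C, "mov", "rsi", "rcx"),
]
regs, tainted, insts = replay(trace)
print([hex(a) for a in insts])               # prints: ['0x44828d', '0x4482cc']
print(regs["rsi"] == BUF, "rsi" in tainted)  # prints: True False
```

rsi ends up holding 0x4c2290 in both cases, but only the copy that flowed from the tainted rdi is tainted, which matches the two addresses Triton reported.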

Code Instrumentation with Triton

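A bare emulation loop cannot resolve .got.plt entries, because no dynamic linker runs. To work around this, we can instrument the emulation. First, point the emulated .got.plt entries at fake addresses of our own. Then, whenever execution reaches one of those addresses, run a Python replacement for the library function instead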
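The mechanism comes down to a lookup table and a simulated `ret`. The sketch below uses a plain Python list as the stack and a hypothetical fake base address, with no Triton involved. When the emulated pc lands on a fake GOT address, the matching Python function runs. Its return value goes into rax, and the return address is popped from the stack into rip.

```python
# Toy hook dispatcher. FAKE_GOT and the dict-based CPU state are illustrative only.
FAKE_GOT = 0x10000000

def make_hooks(handlers):
    """Assign consecutive fake addresses to the handler functions."""
    return {FAKE_GOT + i: func for i, func in enumerate(handlers)}

def step_hooks(state, hooks):
    """If state['rip'] is a fake GOT address, run its hook and return to the caller."""
    func = hooks.get(state["rip"])
    if func is None:
        return False
    ret_value = func(state)
    if ret_value is not None:
        state["rax"] = ret_value          # the return value goes into rax
    state["rip"] = state["stack"].pop()   # emulate 'ret': pop the return address
    return True

def fake_malloc(state):
    state["last_size"] = state["rdi"]     # remember the requested size
    return 0x555555559000

hooks = make_hooks([fake_malloc])
# The caller executed 'call malloc@plt', which pushed 0x11FF and jumped through the GOT.
state = {"rip": FAKE_GOT, "rdi": 0x10, "rax": 0, "stack": [0x11FF]}
step_hooks(state, hooks)
print(hex(state["rax"]), hex(state["rip"]))  # prints: 0x555555559000 0x11ff
```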

gcc test.c -o test
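#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

int main() {
    char *buf = malloc(0x10);

    read(0, buf, 0x200);   /* heap overflow: up to 0x200 bytes into a 0x10-byte chunk */
    printf(buf);           /* format string bug */

    int fd = open(buf, 0);
    read(fd, buf, 0x10);
    write(1, buf, 0x10);

    free(buf);
    free(buf);             /* double free */
    return 0;
}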
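  • The program contains several bugs: a heap overflow (read() of up to 0x200 bytes into a 0x10-byte chunk), a format string bug (printf(buf)) and a double free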

The instrumentation script:

from triton import *
import lief

BASE_GOT = 0x10000000        # fake "function addresses" written into .got.plt
BASE_STACK = 0x7ffffffde000

heap_addr = 0x555555559000   # chunk address returned by the fake malloc
heap_size = 0                # size passed to the last malloc
is_free = False              # whether the chunk has already been freed

def hookingHandler(ctx):
    pc = ctx.getConcreteRegisterValue(ctx.registers.rip)
    for rel in customRelocation:
        if rel[2] == pc:  # we just jumped through the GOT to one of our fake addresses
            ret_value = rel[1](ctx)  # call the hook function and get its return value
            if ret_value is not None:
                ctx.setConcreteRegisterValue(ctx.registers.rax, ret_value)  # the return value goes into rax
            # Emulate 'ret': the return address is on top of the stack
            rsp = ctx.getConcreteRegisterValue(ctx.registers.rsp)
            ret_addr = ctx.getConcreteMemoryValue(MemoryAccess(rsp, CPUSIZE.QWORD))
            ctx.setConcreteRegisterValue(ctx.registers.rip, ret_addr)             # set rip
            ctx.setConcreteRegisterValue(ctx.registers.rsp, rsp + CPUSIZE.QWORD)  # adjust rsp
            return

def mymalloc(ctx):
    global heap_size
    print('[+] malloc hooked')
    heap_size = ctx.getConcreteRegisterValue(ctx.registers.rdi)  # requested size
    return heap_addr

def myfree(ctx):
    global is_free
    print('[+] free hooked')
    if is_free:  # the chunk was already freed
        print('[-] free BUG')
    is_free = True
    return 0

def mylibc(ctx):
    print('[+] __libc_start_main hooked')
    return 0

def myopen(ctx):
    print("[+] open hooked")
    arg1 = ctx.getConcreteRegisterValue(ctx.registers.rdi)  # pathname (not used here)
    arg2 = ctx.getConcreteRegisterValue(ctx.registers.rsi)  # flags (not used here)
    return 0

def myread(ctx):
    print("[+] read hooked")
    read_size = ctx.getConcreteRegisterValue(ctx.registers.rdx)
    if read_size > heap_size:  # reading more than the chunk can hold
        print("[-] read BUG")
    return 0

def mywrite(ctx):
    print("[+] write hooked")
    write_size = ctx.getConcreteRegisterValue(ctx.registers.rdx)
    if write_size > heap_size:  # writing out more than the chunk holds
        print("[-] write BUG")
    return 0

def myprintf(ctx):
    print("[+] printf hooked")
    return 0

customRelocation = [  # [symbol name, hook function, fake GOT address (filled in by makeRelocation)]
    ['__libc_start_main', mylibc, None],
    ['malloc', mymalloc, None],
    ['free', myfree, None],
    ['open', myopen, None],
    ['read', myread, None],
    ['write', mywrite, None],
    ['printf', myprintf, None],
]

def makeRelocation(ctx, binary):
    for pltIndex in range(len(customRelocation)):
        customRelocation[pltIndex][2] = BASE_GOT + pltIndex  # assign each hook a fake address

    relocations = [x for x in binary.pltgot_relocations]          # PLT/GOT relocations
    relocations.extend([x for x in binary.dynamic_relocations])  # dynamic relocations

    for rel in relocations:
        symbolName = rel.symbol.name
        symbolRelo = rel.address  # address of the GOT slot
        for crel in customRelocation:
            if symbolName == crel[0]:
                print('Hooking %-10s:0x%x' % (symbolName, symbolRelo))
                # Point the emulated GOT slot at our fake address based on BASE_GOT
                ctx.setConcreteMemoryValue(MemoryAccess(symbolRelo, CPUSIZE.QWORD), crel[2])
                break
    return

def loadBinary(ctx, binary):
    phdrs = binary.segments  # all segments (program headers)
    for phdr in phdrs:
        size = phdr.physical_size
        vaddr = phdr.virtual_address
        print('[+] Loading 0x%06x - 0x%06x' % (vaddr, vaddr + size))
        ctx.setConcreteMemoryAreaValue(vaddr, list(phdr.content))
    return


if __name__ == '__main__':
    ctx = TritonContext(ARCH.X86_64)
    ctx.setMode(MODE.ALIGNED_MEMORY, True)
    ctx.setMode(MODE.CONSTANT_FOLDING, True)

    binary = lief.ELF.parse("./test")

    loadBinary(ctx, binary)
    makeRelocation(ctx, binary)
    ctx.setConcreteRegisterValue(ctx.registers.rbp, BASE_STACK)
    ctx.setConcreteRegisterValue(ctx.registers.rsp, BASE_STACK)

    pc = 0x11E9  # address of main (from IDA; the PIE is mapped at base 0 here)
    while pc:
        inst = Instruction()
        opcode = ctx.getConcreteMemoryAreaValue(pc, 16)

        inst.setOpcode(opcode)
        inst.setAddress(pc)
        ctx.processing(inst)
        print(str(inst))

        hookingHandler(ctx)  # handle hooks: run one if we jumped to a fake GOT address
        pc = ctx.getConcreteRegisterValue(ctx.registers.rip)

        if pc == 0x11ff:  # right after malloc returned
            print("[+] rax: " + hex(ctx.getConcreteRegisterValue(ctx.registers.rax)))

Output:

[+] Loading 0x000040 - 0x000318
[+] Loading 0x000318 - 0x000334
[+] Loading 0x000000 - 0x000710
[+] Loading 0x001000 - 0x001325
[+] Loading 0x002000 - 0x002150
[+] Loading 0x003d90 - 0x004010
[+] Loading 0x003da0 - 0x003f90
[+] Loading 0x000338 - 0x000358
[+] Loading 0x000358 - 0x00039c
[+] Loading 0x000338 - 0x000358
[+] Loading 0x002004 - 0x002048
[+] Loading 0x000000 - 0x000000
[+] Loading 0x003d90 - 0x004000
Hooking free
Hooking write
Hooking printf
Hooking read
Hooking malloc
Hooking open
Hooking __libc_start_main
0x11e9: endbr64
0x11ed: push rbp
0x11ee: mov rbp, rsp
0x11f1: sub rsp, 0x10
0x11f5: mov edi, 0x10
0x11fa: call 0x10e0
0x10e0: endbr64
0x10e4: bnd jmp qword ptr [rip + 0x2edd]
[+] malloc hooked
[+] rax: 0x555555559000
0x11ff: mov qword ptr [rbp - 8], rax
0x1203: mov rax, qword ptr [rbp - 8]
0x1207: mov edx, 0x200
0x120c: mov rsi, rax
0x120f: mov edi, 0
0x1214: mov eax, 0
0x1219: call 0x10d0
0x10d0: endbr64
0x10d4: bnd jmp qword ptr [rip + 0x2ee5]
[+] read hooked
[-] read BUG
0x121e: mov rax, qword ptr [rbp - 8]
0x1222: mov rdi, rax
0x1225: mov eax, 0
0x122a: call 0x10c0
0x10c0: endbr64
0x10c4: bnd jmp qword ptr [rip + 0x2eed]
[+] printf hooked
0x122f: mov rax, qword ptr [rbp - 8]
0x1233: mov esi, 0
0x1238: mov rdi, rax
0x123b: mov eax, 0
0x1240: call 0x10f0
0x10f0: endbr64
0x10f4: bnd jmp qword ptr [rip + 0x2ed5]
[+] open hooked
0x1245: mov dword ptr [rbp - 0xc], eax
0x1248: mov rcx, qword ptr [rbp - 8]
0x124c: mov eax, dword ptr [rbp - 0xc]
0x124f: mov edx, 0x10
0x1254: mov rsi, rcx
0x1257: mov edi, eax
0x1259: mov eax, 0
0x125e: call 0x10d0
0x10d0: endbr64
0x10d4: bnd jmp qword ptr [rip + 0x2ee5]
[+] read hooked
0x1263: mov rax, qword ptr [rbp - 8]
0x1267: mov edx, 0x10
0x126c: mov rsi, rax
0x126f: mov edi, 1
0x1274: mov eax, 0
0x1279: call 0x10b0
0x10b0: endbr64
0x10b4: bnd jmp qword ptr [rip + 0x2ef5]
[+] write hooked
0x127e: mov rax, qword ptr [rbp - 8]
0x1282: mov rdi, rax
0x1285: call 0x10a0
0x10a0: endbr64
0x10a4: bnd jmp qword ptr [rip + 0x2efd]
[+] free hooked
0x128a: mov rax, qword ptr [rbp - 8]
0x128e: mov rdi, rax
0x1291: call 0x10a0
0x10a0: endbr64
0x10a4: bnd jmp qword ptr [rip + 0x2efd]
[+] free hooked
[-] free BUG
0x1296: mov eax, 0
0x129b: leave
0x129c: ret
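The bug checks performed by mymalloc, myread and myfree do not depend on Triton itself. Below is a minimal standalone sketch of the same bookkeeping, generalized to track several chunks. The class and method names are hypothetical. It remembers each allocation's size, flags reads larger than the destination chunk, and flags a second free of the same pointer.

```python
# Toy heap bookkeeping mirroring the checks done by the hooks above.

class HeapChecker:
    """Track live chunks and record the bug classes the hooks look for."""

    def __init__(self, base=0x555555559000):
        self.next_addr = base
        self.chunks = {}     # address -> requested size of live chunks
        self.freed = set()   # addresses already passed to free()
        self.bugs = []

    def malloc(self, size):
        addr = self.next_addr
        self.next_addr += (max(size, 1) + 0xF) & ~0xF   # 16-byte aligned bump allocation
        self.chunks[addr] = size
        return addr

    def read(self, buf, count):
        size = self.chunks.get(buf)
        if size is not None and count > size:
            self.bugs.append(f"read overflow: {count:#x} bytes into a {size:#x}-byte chunk")

    def free(self, ptr):
        if ptr in self.freed:
            self.bugs.append(f"double free of {ptr:#x}")
            return
        self.chunks.pop(ptr, None)
        self.freed.add(ptr)


# Replay the calls made by the test program.
heap = HeapChecker()
buf = heap.malloc(0x10)
heap.read(buf, 0x200)
heap.read(buf, 0x10)
heap.free(buf)
heap.free(buf)
print(heap.bugs)
```

Keeping a map of live chunks instead of the single global heap_size used by the hooks lets the same checks work when the program allocates more than one buffer.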

Symbolic Execution with Triton

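Symbolic execution is a program analysis technique that uses symbolic values instead of concrete values to explore a program's possible execution paths. It helps analysts and automated tools understand program behaviour and find latent bugs or security vulnerabilities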
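Below is a minimal sketch of the idea on a hypothetical toy program written as a tree of branches. Execution forks at every branch and records the path condition needed to reach each leaf. A brute-force search over one input byte then stands in for the SMT solver (such as Z3) that Triton would use.

```python
# A toy program as a tree: ("if", condition, then, else) or ("ret", label).
# Conditions are functions of the symbolic input byte x.
TOY = ("if", lambda x: x > 10,
       ("if", lambda x: 3 * x == 42, ("ret", "win"), ("ret", "lose")),
       ("ret", "small"))

def explore(node, path=()):
    """Fork at every branch, yielding (label, path_condition) for each path."""
    if node[0] == "ret":
        yield node[1], list(path)
        return
    _, cond, then_branch, else_branch = node
    yield from explore(then_branch, path + (cond,))
    yield from explore(else_branch, path + (lambda x, c=cond: not c(x),))

def solve(path_condition, domain=range(256)):
    """Brute-force stand-in for a solver: first x satisfying every constraint."""
    for x in domain:
        if all(c(x) for c in path_condition):
            return x
    return None

for label, path_condition in explore(TOY):
    print(label, solve(path_condition))
# prints:
# win 14
# lose 11
# small 0
```

A real engine builds the conditions from instruction semantics instead of hand-written lambdas, but the fork, collect and solve cycle is the same.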

The symbolic execution script:

from triton import *
import lief
import os

def AnalysisBinary(path):
    binary = lief.ELF.parse(path)  # parse the ELF file with LIEF

    sections = binary.sections
    for section in sections:  # walk all sections
        name = section.name
        size = section.size
        vaddr = section.virtual_address
        if name != "":
            print('[+] %-28s 0x%06x - 0x%06x' % (name + ":", vaddr, vaddr + size))
            data = bytes(section.content[0:0x20])  # hexdump the first 0x20 bytes
            for i in range(len(data)):
                print(f"{data[i]:02x}", end=" ")
                if i % 0x10 == 0xf: print()
            if len(data) % 0x10: print()  # end a partial last line (also safe for empty sections)
            print("-------------------------------------------------")
            ctx.setConcreteMemoryAreaValue(vaddr, bytes(section.content))  # map the section at its virtual address

    symbols = binary.symbols  # walk all symbols
    for symbol in symbols:
        if symbol.name == 'main':
            main_address = symbol.value
            print()
            print('[+] Address of main function: 0x{:x}'.format(main_address))
            print()
            break
    else:
        print("'main' function not found")

if __name__ == '__main__':
    ctx = TritonContext()
    ctx.setArchitecture(ARCH.X86_64)
    ast = ctx.getAstContext()

    AnalysisBinary(os.path.join(os.path.dirname(__file__), './test'))

    start = 0x401DE0  # an address just before the call to the check function (from IDA)
    RBP_ADDR = 0x7ffffffde000
    RSP_ADDR = RBP_ADDR - 0x21000

    ctx.setAstRepresentationMode(AST_REPRESENTATION.PYTHON)

    ctx.setConcreteRegisterValue(ctx.registers.rip, start)
    ctx.setConcreteRegisterValue(ctx.registers.rsp, RSP_ADDR)
    ctx.setConcreteRegisterValue(ctx.registers.rbp, RBP_ADDR)

    input_addr = ctx.getConcreteRegisterValue(ctx.registers.rbp) - 0x20  # buf is at rbp-0x20
    for i in range(5):
        ctx.setConcreteMemoryValue(MemoryAccess(input_addr + i, CPUSIZE.BYTE), ord('a'))  # concrete seed value
        ctx.symbolizeMemory(MemoryAccess(input_addr + i, CPUSIZE.BYTE))  # turn this byte into a symbolic variable

    pc = start

    while pc:
        inst = Instruction()
        opcode = ctx.getConcreteMemoryAreaValue(pc, 16)

        inst.setOpcode(opcode)
        inst.setAddress(pc)
        ctx.processing(inst)

        print(str(inst))
        if inst.getAddress() == 0x401DEC:  # first instruction after call check: rax holds its return value
            rdata = ctx.getRegisterAst(ctx.registers.rax)  # AST of rax, usable in the constraints below

            sv1 = ast.variable(ctx.getSymbolicVariable(0))  # the symbolic variables created above
            sv2 = ast.variable(ctx.getSymbolicVariable(1))
            sv3 = ast.variable(ctx.getSymbolicVariable(2))
            sv4 = ast.variable(ctx.getSymbolicVariable(3))
            sv5 = ast.variable(ctx.getSymbolicVariable(4))

            cstr = ast.land([rdata == 0xAD6D]  # check() must return 0xAD6D
                            + [sv1 >= ord(b'A'), sv1 <= ord(b'z')]  # and every byte must lie in 'A'..'z'
                            + [sv2 >= ord(b'A'), sv2 <= ord(b'z')]
                            + [sv3 >= ord(b'A'), sv3 <= ord(b'z')]
                            + [sv4 >= ord(b'A'), sv4 <= ord(b'z')]
                            + [sv5 >= ord(b'A'), sv5 <= ord(b'z')]
                            )

            model = ctx.getModel(cstr)  # solve the constraints
            answer = ""
            for k, v in sorted(model.items()):  # keys are the variable ids, so this follows input order
                answer += chr(v.getValue())

            if len(answer) == 5:
                print("answer: {}".format(answer))

            break

        pc = ctx.getConcreteRegisterValue(ctx.registers.rip)

The program under test:

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

char *serial = "\x31\x3e\x3d\x26\x31";

int check(char *ptr) {
    int i;
    int hash = 0xABCD;

    for (i = 0; ptr[i]; i++)
        hash += ptr[i] ^ serial[i % 5];

    return hash;
}

int main() {
    int ret;
    char buf[0x10];
    read(0, buf, 5); // a valid input: TyrbP (note: buf is not NUL-terminated, check() relies on buf[5] being 0)

    ret = check(buf);
    if (ret == 0xad6d)
        printf("Win\n");
    else
        printf("fail\n");

    return 0;
}

First, use IDA to locate the starting address 0x401DE0:

.text:0000000000401DE0 48 8D 45 E0                   lea     rax, [rbp+var_20]
.text:0000000000401DE4 48 89 C7 mov rdi, rax
.text:0000000000401DE7 E8 39 FF FF FF call check
.text:0000000000401DE7
.text:0000000000401DEC 89 45 DC mov [rbp+var_24], eax
.text:0000000000401DEF 81 7D DC 6D AD 00 00 cmp [rbp+var_24], 0AD6Dh
.text:0000000000401DF6 75 11 jnz short loc_401E09
  • The program prints Win exactly when check returns 0xAD6D, so this return value can be used as the constraint
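As a sanity check independent of Triton, check() can be ported to Python to confirm that the input noted in the source comment, TyrbP, really produces 0xAD6D. The port assumes ASCII input, so it ignores the signedness of C's char. It also leaves out the C int's 32-bit limit, which short inputs never reach.

```python
SERIAL = b"\x31\x3e\x3d\x26\x31"

def check(data: bytes) -> int:
    """Python port of check(): stops at the first NUL byte, like the C loop."""
    h = 0xABCD
    for i, byte in enumerate(data):
        if byte == 0:
            break
        h += byte ^ SERIAL[i % 5]
    return h

print(hex(check(b"TyrbP")))  # prints: 0xad6d
```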

Solving a CTF Challenge with Triton

Challenge: alexctf-2017-re2-cpp-is-awesome

The core checking logic (IDA decompilation):

i = 0;
flag[0] = std::string::begin(string);
while ( 1 )
{
  v13 = std::string::end(string);
  if ( !sub_400D3D((__int64)flag, (__int64)&v13) )
    break;
  v8 = *(unsigned __int8 *)get_char((__int64)flag);
  if ( (_BYTE)v8 != key[index[i]] )
    fail((__int64)flag, (__int64)&v13, v8);
  ++i;
  sub_400D7A(flag);
}
  • In essence this is a simple permutation check: the i-th input character must equal key[index[i]]
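A Python transliteration of the loop makes the structure explicit. The key and index contents below are made-up stand-ins, not the challenge's real tables. Character i of the input is compared with key[index[i]], so anyone who dumps key and index could rebuild the flag directly. The symbolic execution approach used below recovers it without doing that analysis by hand.

```python
def first_mismatch(flag: bytes, key: bytes, index: list) -> int:
    """Return the position of the first wrong character, or -1 if the flag passes."""
    for i, ch in enumerate(flag):
        if ch != key[index[i]]:
            return i      # the real program calls fail() here
    return -1

def rebuild_flag(key: bytes, index: list) -> bytes:
    """The check is a pure permutation lookup, so the flag is key[index[i]] for each i."""
    return bytes(key[j] for j in index)

# Hypothetical toy data (NOT the challenge's real key/index tables).
key = b"aeflgx{}"
index = [2, 3, 0, 4]
flag = rebuild_flag(key, index)
print(flag, first_mismatch(flag, key, index))  # prints: b'flag' -1
```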

This challenge calls for symbolic execution. First, use IDA to locate the code that decides the control flow:

.text:0000000000400C4B 48 8D 45 A0                   lea     rax, [rbp+flag]
.text:0000000000400C4F 48 89 C7 mov rdi, rax
.text:0000000000400C52 E8 43 01 00 00 call get_char
.text:0000000000400C52
.text:0000000000400C57 0F B6 10 movzx edx, byte ptr [rax]
.text:0000000000400C5A 48 8B 0D 3F 14 20 00 mov rcx, cs:key
.text:0000000000400C61 8B 45 EC mov eax, [rbp+i]
.text:0000000000400C64 48 98 cdqe
.text:0000000000400C66 8B 04 85 C0 20 60 00 mov eax, index[rax*4]
.text:0000000000400C6D 48 98 cdqe
.text:0000000000400C6F 48 01 C8 add rax, rcx
.text:0000000000400C72 0F B6 00 movzx eax, byte ptr [rax]
.text:0000000000400C75 38 C2 cmp dl, al
.text:0000000000400C77 0F 95 C0 setnz al
.text:0000000000400C7A 84 C0 test al, al
.text:0000000000400C7C 74 05 jz short loc_400C83
.text:0000000000400C7C
.text:0000000000400C7E ; try {
.text:0000000000400C7E E8 D3 FE FF FF call fail
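  • The snippet runs from lea rax, [rbp+flag] (load the flag iterator) to test al, al / jz (decide whether fail is called). The core constraint is ZF == 1 after test al, al. When cmp dl, al finds the input byte equal to key[index[i]], setnz al writes 0, so test al, al sets ZF and jz skips the call to fail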
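The flag logic can be checked with a tiny model of the three instructions. It is a simplified model that tracks only ZF, not a full x86 flags implementation. cmp dl, al sets ZF when the bytes are equal, and setnz al writes 1 exactly when ZF is 0. test al, al then sets ZF when al is 0. Chaining the three and checking every byte pair confirms that the final ZF is 1 exactly when dl == al.

```python
def zf_after_cmp(a: int, b: int) -> int:
    """cmp a, b computes a - b; ZF is 1 when the 8-bit result is zero."""
    return int(((a - b) & 0xFF) == 0)

def setnz(zf: int) -> int:
    """setnz writes 1 when ZF == 0, otherwise 0."""
    return 1 - zf

def zf_after_test(x: int) -> int:
    """test x, x computes x & x; ZF is 1 when the 8-bit result is zero."""
    return int(((x & x) & 0xFF) == 0)

def final_zf(dl: int, al: int) -> int:
    al = setnz(zf_after_cmp(dl, al))   # cmp dl, al ; setnz al
    return zf_after_test(al)           # test al, al

# ZF == 1 (jz taken, fail() skipped) exactly when the two bytes are equal.
assert all(final_zf(d, a) == int(d == a) for d in range(256) for a in range(256))
print(final_zf(0x41, 0x41), final_zf(0x41, 0x42))  # prints: 1 0
```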

Based on this constraint, the solver script is:

from triton import *
import lief
import sys

def AnalysisBinary(path):
    binary = lief.ELF.parse(path)
    sections = binary.sections
    for section in sections:
        name = section.name
        size = section.size
        vaddr = section.virtual_address
        if name != "":
            print('[+] %-28s 0x%06x - 0x%06x' % (name + ":", vaddr, vaddr + size))
            data = bytes(section.content[0:0x20])
            for i in range(len(data)):
                print(f"{data[i]:02x}", end=" ")
                if i % 0x10 == 0xf: print()
            if len(data) % 0x10: print()  # end a partial last line (also safe for empty sections)
            print("-------------------------------------------------")
            ctx.setConcreteMemoryAreaValue(vaddr, bytes(section.content))

if __name__ == '__main__':
    ctx = TritonContext(ARCH.X86_64)
    ctx.setMode(MODE.ALIGNED_MEMORY, True)
    ctx.setMode(MODE.ONLY_ON_SYMBOLIZED, True)  # only build expressions for data that depends on symbolic input

    AnalysisBinary("./test")

    start = 0x400c4b        # lea rax, [rbp+flag]
    end = 0x400C7C          # jz short loc_400C83
    input_addr = 0x002000   # scratch address that holds the emulated input string
    for i in range(31):
        ctx.setConcreteMemoryValue(MemoryAccess(input_addr + i, CPUSIZE.BYTE), 0x61)  # seed with 'a'
        ctx.symbolizeMemory(MemoryAccess(input_addr + i, CPUSIZE.BYTE))  # make every input byte a symbolic variable

    rbp = 0x7fffffffe460
    ctx.setConcreteRegisterValue(ctx.registers.rbp, rbp)
    ctx.setConcreteRegisterValue(ctx.registers.rip, start)
    ctx.setConcreteMemoryValue(MemoryAccess(rbp - 96, CPUSIZE.QWORD), input_addr)  # rbp-96 (flag): iterator into the input string
    ctx.setConcreteMemoryValue(MemoryAccess(rbp - 20, CPUSIZE.DWORD), 0)           # rbp-20 (i): index used by the check

    for count in range(31):
        pc = start

        while pc:  # emulate until we reach the conditional jump
            inst = Instruction()
            opcode = ctx.getConcreteMemoryAreaValue(pc, 16)
            inst.setOpcode(opcode)
            inst.setAddress(pc)
            ctx.processing(inst)
            pc = ctx.getConcreteRegisterValue(ctx.registers.rip)
            print(inst)
            if pc == end:
                print("------------------------------------------")
                break

        zf = ctx.getRegisterAst(ctx.registers.zf)
        ctx.pushPathConstraint(zf == 1)  # require ZF == 1: the byte matches and fail() is skipped

        ctx.setConcreteMemoryValue(MemoryAccess(rbp - 20, CPUSIZE.DWORD), count + 1)  # advance the index
        ctx.setConcreteMemoryValue(MemoryAccess(rbp - 96, CPUSIZE.QWORD), input_addr + count + 1)  # advance the iterator (QWORD, like the initial write)

    mod = ctx.getModel(ctx.getPathPredicate())  # solve all collected path constraints together
    if not mod:
        print('[-] Failed')
        sys.exit(-1)

    flag = ""
    for k, v in sorted(mod.items()):
        ctx.setConcreteVariableValue(ctx.getSymbolicVariable(v.getId()), v.getValue())  # feed the solution back as concrete values
        flag += chr(v.getValue())

    print("flag: {}".format(flag))
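  • The byte each position is compared with (key[index[i]]) does not depend on the input, and each comparison involves only one input byte, independent of the other positions. A separate constraint can therefore be added for every position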
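Because each constraint mentions only one input byte, the positions can be solved independently. For a 31-byte flag that means 31 searches over 256 values instead of one search over 256**31 combinations. The sketch below uses a brute-force search in place of the SMT solver. `byte_ok` is a hypothetical per-position oracle standing in for the collected `zf == 1` constraints, and the target string is made up.

```python
def solve_independent(length, byte_ok):
    """Solve position by position; byte_ok(i, c) says whether byte c is valid at position i."""
    result = bytearray()
    for i in range(length):
        for c in range(256):
            if byte_ok(i, c):
                result.append(c)
                break
        else:
            raise ValueError(f"no valid byte at position {i}")
    return bytes(result)

# Hypothetical oracle: position i must equal a hidden byte (made-up target).
hidden = b"demo_flag"
print(solve_independent(len(hidden), lambda i, c: c == hidden[i]))  # prints: b'demo_flag'
```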

Installing and Using exrop

Install exrop:

git clone https://github.com/d4em0n/exrop
  • Then add export PYTHONPATH=/home/yhellow/Tools/exrop:$PYTHONPATH to .zshenv in the home directory (/home/yhellow), assuming the repository was cloned into /home/yhellow/Tools

The first time you use exrop, you may run into this error:

TypeError: Z3ToTriton::visit(): 'SymVar_0' AST node not supported yet

After some digging, I traced this error to Z3_OP_UNINTERPRETED, whose numeric value differs between the C++ bindings and the Python bindings

As a result, Triton recognizes Z3_OP_UNINTERPRETED as Z3_OP_RECURSIVE. To fix this bug I patched the Triton source (in Z3ToTriton::visit()):

case Z3_OP_RECURSIVE: /* added: handle Z3_OP_RECURSIVE the same way as Z3_OP_UNINTERPRETED */
case Z3_OP_UNINTERPRETED: {
  std::string name = function.name().str();

  node = this->astCtxt->getVariableNode(name);
  if (node == nullptr) {
    node = this->astCtxt->string(name);
  }

  break;
}

default:
  throw triton::exceptions::AstLifting("Z3ToTriton::visit(): '" + function.name().str() + "' AST node not supported yet");
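  • Stock Triton has no case for Z3_OP_RECURSIVE, so such a node falls through to default and raises the error above. After patching, rebuild and reinstall Triton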

exrop's main purpose is to generate ROP chains automatically
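At the byte level, a ROP chain is only a sequence of little-endian 64-bit words: gadget addresses interleaved with the values those gadgets pop into registers. The minimal sketch below uses only Python's struct module to rebuild, by hand, the second execve chain from the output further below. The gadget offsets and annotations come from the first dump. The "/bin/sh" offset 0x1b45bd is decoded from the payload_str() bytes. All values are therefore specific to the author's libc with base address 0. With a real libc base, only the gadget and string addresses would be rebased, not the constants.

```python
import struct

# (value, note) pairs for the execve("/bin/sh", 0, 0) chain, offsets relative to libc base 0.
CHAIN = [
    (0x36174, "pop rax ; ret"),
    (0x3B, "rax = 59 (execve)"),
    (0x23B6A, "pop rdi ; ret"),
    (0x1B45BD, 'rdi = address of "/bin/sh" in libc'),
    (0x2601F, "pop rsi ; ret"),
    (0x0, "rsi = 0"),
    (0x15FAE6, "pop rdx ; pop rbx ; ret"),
    (0x0, "rdx = 0"),
    (0x0, "rbx = 0 (unused)"),
    (0x630A9, "syscall ; ret"),
]

def pack_chain(chain):
    """Serialize the chain as consecutive little-endian 64-bit words."""
    return b"".join(struct.pack("<Q", value) for value, _ in chain)

def dump_chain(chain):
    """Print one line per stack slot, in a layout similar to exrop's chain.dump()."""
    for slot, (value, note) in enumerate(chain):
        print(f"$RSP+0x{slot * 8:04x} : 0x{value:016x} # {note}")

payload = pack_chain(CHAIN)
dump_chain(CHAIN)
print(len(payload))  # prints: 80
```

The packed bytes are identical to what payload_str() printed for this chain, which is a quick way to double-check a generated chain by hand.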

An exrop usage example:

from pwn import *
import time
from Exrop import Exrop

binname = "/lib/x86_64-linux-gnu/libc.so.6"
libc = ELF(binname, checksec=False)
binsh = next(libc.search(b"/bin/sh"))
open_libc = libc.symbols['open']
read_libc = libc.symbols['read']
write_libc = libc.symbols['write']
bss = libc.bss()

t = time.mktime(time.gmtime())
rop = Exrop(binname)
rop.find_gadgets(cache=True)

print("execve('/bin/sh', 0, 0)")
chain = rop.syscall(59, ("/bin/sh", 0, 0), bss)  # string arguments are written to the writable address given as the 3rd argument (bss)
chain.set_base_addr(0)
chain.dump()

print("execve('/bin/sh', 0, 0)")
chain = rop.syscall(59, (binsh, 0, 0))  # without that 3rd argument strings cannot be written, so pass the address of an existing "/bin/sh"
chain.set_base_addr(0)  # set the base address (0 here, so the chain shows offsets into libc)
print(chain.payload_str())  # serialize the ROP chain into a bytes string

print("open('/etc/passwd', 0)")
chain = rop.func_call(open_libc, ("/etc/passwd", 0), bss)
chain.set_base_addr(0)
chain.dump()

print("read('rax', bss, 0x100)")
chain = rop.func_call(read_libc, ('rax', bss, 0x100))  # a register can be used directly as an argument
chain.set_base_addr(0)
chain.dump()

print("write(1, bss, 0x100)")
chain = rop.func_call(write_libc, (1, bss, 0x100))
chain.set_base_addr(0)
chain.dump()

print("done in {}s".format(time.mktime(time.gmtime()) - t))
  • Output:
execve('/bin/sh', 0, 0)
$RSP+0x0000 : 0x0000000000036174 # pop rax ; ret
$RSP+0x0008 : 0x00000000001ed7a0
$RSP+0x0010 : 0x0000000000023b6a # pop rdi ; ret
$RSP+0x0018 : 0x0068732f6e69622f
$RSP+0x0020 : 0x000000000009a0cf # mov qword ptr [rax], rdi ; ret
$RSP+0x0028 : 0x0000000000036174 # pop rax ; ret
$RSP+0x0030 : 0x000000000000003b
$RSP+0x0038 : 0x0000000000023b6a # pop rdi ; ret
$RSP+0x0040 : 0x00000000001ed7a0
$RSP+0x0048 : 0x000000000002601f # pop rsi ; ret
$RSP+0x0050 : 0x0000000000000000
$RSP+0x0058 : 0x000000000015fae6 # pop rdx ; pop rbx ; ret
$RSP+0x0060 : 0x0000000000000000
$RSP+0x0068 : 0x0000000000000000
$RSP+0x0070 : 0x00000000000630a9 # syscall ; ret

execve('/bin/sh', 0, 0)
b'ta\x03\x00\x00\x00\x00\x00;\x00\x00\x00\x00\x00\x00\x00j;\x02\x00\x00\x00\x00\x00\xbdE\x1b\x00\x00\x00\x00\x00\x1f`\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe6\xfa\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa90\x06\x00\x00\x00\x00\x00'
open('/etc/passwd', 0)
$RSP+0x0000 : 0x0000000000036174 # pop rax ; ret
$RSP+0x0008 : 0x00000000001ed7a0
$RSP+0x0010 : 0x0000000000023b6a # pop rdi ; ret
$RSP+0x0018 : 0x7361702f6374652f
$RSP+0x0020 : 0x000000000009a0cf # mov qword ptr [rax], rdi ; ret
$RSP+0x0028 : 0x0000000000036174 # pop rax ; ret
$RSP+0x0030 : 0x00000000001ed7a8
$RSP+0x0038 : 0x0000000000023b6a # pop rdi ; ret
$RSP+0x0040 : 0x0000000000647773
$RSP+0x0048 : 0x000000000009a0cf # mov qword ptr [rax], rdi ; ret
$RSP+0x0050 : 0x0000000000023b6a # pop rdi ; ret
$RSP+0x0058 : 0x00000000001ed7a0
$RSP+0x0060 : 0x000000000002601f # pop rsi ; ret
$RSP+0x0068 : 0x0000000000000000
$RSP+0x0070 : 0x000000000010df00

read('rax', bss, 0x100)
$RSP+0x0000 : 0x000000000015fae6 # pop rdx ; pop rbx ; ret
$RSP+0x0008 : 0x0000000000000000
$RSP+0x0010 : 0x0000000000023b6a
$RSP+0x0018 : 0x0000000000044808 # mov r13, rax ; mov rdi, r12 ; call rbx: next -> (0x00023b6a) # pop rdi ; ret
$RSP+0x0020 : 0x0000000000036174 # pop rax ; ret
$RSP+0x0028 : 0x000000000002601f
$RSP+0x0030 : 0x0000000000045872 # mov rdi, r13 ; call rax: next -> (0x0002601f) # pop rsi ; ret
$RSP+0x0038 : 0x000000000002601f # pop rsi ; ret
$RSP+0x0040 : 0x00000000001ed7a0
$RSP+0x0048 : 0x000000000015fae6 # pop rdx ; pop rbx ; ret
$RSP+0x0050 : 0x0000000000000100
$RSP+0x0058 : 0x0000000000000000
$RSP+0x0060 : 0x000000000010e1e0

write(1, bss, 0x100)
$RSP+0x0000 : 0x0000000000023b6a # pop rdi ; ret
$RSP+0x0008 : 0x0000000000000001
$RSP+0x0010 : 0x000000000002601f # pop rsi ; ret
$RSP+0x0018 : 0x00000000001ed7a0
$RSP+0x0020 : 0x000000000015fae6 # pop rdx ; pop rbx ; ret
$RSP+0x0028 : 0x0000000000000100
$RSP+0x0030 : 0x0000000000000000
$RSP+0x0038 : 0x000000000010e280

done in 2.0s
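The qword constants in these dumps are just bytes stored in little-endian order. For example, 0x0068732f6e69622f is "/bin/sh\0". Similarly, 0x7361702f6374652f followed by 0x647773 spells "/etc/passwd" across two 8-byte writes, to bss and bss+8. The sketch below assumes the usual x86-64 little-endian stack layout and shows that packing in Go, roughly what payload_str() has to produce. It is not exrop's code, and packChain and qwordToString are made-up names:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"strings"
)

// packChain writes each chain slot as an 8-byte little-endian qword,
// which is how consecutive slots sit on an x86-64 stack.
func packChain(slots []uint64) []byte {
	buf := make([]byte, 8*len(slots))
	for i, s := range slots {
		binary.LittleEndian.PutUint64(buf[8*i:], s)
	}
	return buf
}

// qwordToString shows the bytes a qword places in memory, minus trailing NULs.
func qwordToString(q uint64) string {
	var b [8]byte
	binary.LittleEndian.PutUint64(b[:], q)
	return strings.TrimRight(string(b[:]), "\x00")
}

func main() {
	// the execve string written to bss
	fmt.Printf("%q\n", qwordToString(0x0068732f6e69622f))
	// the open() path, written in two 8-byte steps
	fmt.Printf("%q\n", qwordToString(0x7361702f6374652f)+qwordToString(0x0000000000647773))
	// "pop rax ; ret" followed by 59
	fmt.Printf("%q\n", packChain([]uint64{0x36174, 0x3b}))
}
```

The last line reproduces the first 16 bytes of the payload_str() output above: b'ta\x03\x00\x00\x00\x00\x00;\x00\x00\x00\x00\x00\x00\x00'.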

Without the cache, rop.find_gadgets takes a very long time (especially on libc.so.6). The following change speeds it up:

def parseRopGadget(filename, opt=""):
    from subprocess import Popen, PIPE, STDOUT
    import re

    cmd = ['ROPgadget', '--binary', filename, '--multibr', '--only',
           'pop|xchg|add|sub|xor|mov|ret|jmp|call|syscall|leave', '--dump']
    if opt:
        cmd.append(opt)
    process = Popen(cmd, stdout=PIPE, stderr=STDOUT)
    stdout, _ = process.communicate()
    ......
  • The parseRopGadget function in Exrop.py runs ROPgadget. Dropping instructions that are not needed from the --only list speeds it up considerably, for example:
    cmd = ['ROPgadget', '--binary', filename, '--multibr', '--only',
           'pop|mov|ret|call|syscall', '--dump']
  • Run time after the change:
done in 28.0s

Installing and using qsynthesis

git clone https://github.com/quarkslab/qsynthesis.git
cd qsynthesis
pip3 install '.[all]'

QSynthesis is a Python3 API for I/O-based program synthesis of bitvector expressions, aimed at helping code deobfuscation

  • The algorithm is a grey-box approach: it combines black-box I/O-based synthesis with a white-box AST search to synthesize sub-expressions

The core synthesis relies on the Triton symbolic engine, and the whole framework is built on top of it. It provides:

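  • Synthesis of bitvector expressions
  • Checking the semantic equivalence of synthesized expressions via SMT
  • Synthesis of constants (when the expression encodes a constant)
  • Improving oracles (pre-computed tables) over time through a learning mechanism
  • Reassembling synthesized expressions back into assembly
  • Serving oracles through a REST API for convenient use during synthesis
  • An IDA plugin that integrates synthesis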

Generate a table with the bundled tool:

qsynthesis-table-manager generate -bs 64 --var-num 3 --input-num 16 --random-level 5 --ops AND,NEG,MUL,XOR,NOT --watchdog 80 --limit 5000000 my_oracle_table
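An oracle table maps the I/O behaviour of many small expressions back to the expressions themselves. An obfuscated expression can then be evaluated on the same inputs and its outputs looked up in the table. Below is a toy Go sketch of that idea. It assumes two 64-bit variables, 16 fixed random inputs (my reading of --input-num 16, which is an assumption) and a hand-picked candidate list. qsynthesis's real table format, enumeration and SMT checks are not reproduced, and all names are hypothetical:

```go
package main

import (
	"fmt"
	"math/rand"
	"strings"
)

// candidate is one entry of the toy table: a readable expression and its semantics.
type candidate struct {
	text string
	eval func(x, y uint64) uint64
}

// A hand-written stand-in for the expressions a real table would enumerate.
var candidates = []candidate{
	{"x", func(x, y uint64) uint64 { return x }},
	{"y", func(x, y uint64) uint64 { return y }},
	{"~x", func(x, y uint64) uint64 { return ^x }},
	{"x&y", func(x, y uint64) uint64 { return x & y }},
	{"x|y", func(x, y uint64) uint64 { return x | y }},
	{"x^y", func(x, y uint64) uint64 { return x ^ y }},
	{"x+y", func(x, y uint64) uint64 { return x + y }},
	{"x*y", func(x, y uint64) uint64 { return x * y }},
}

// makeInputs draws n random (x, y) pairs from a fixed seed.
func makeInputs(seed int64, n int) [][2]uint64 {
	rng := rand.New(rand.NewSource(seed))
	in := make([][2]uint64, n)
	for i := range in {
		in[i] = [2]uint64{rng.Uint64(), rng.Uint64()}
	}
	return in
}

// signature encodes the output vector of f on the inputs as a map key.
func signature(f func(x, y uint64) uint64, inputs [][2]uint64) string {
	var sb strings.Builder
	for _, p := range inputs {
		fmt.Fprintf(&sb, "%x,", f(p[0], p[1]))
	}
	return sb.String()
}

// buildTable maps each output vector to the first candidate producing it.
func buildTable(inputs [][2]uint64) map[string]string {
	table := make(map[string]string)
	for _, c := range candidates {
		sig := signature(c.eval, inputs)
		if _, seen := table[sig]; !seen {
			table[sig] = c.text
		}
	}
	return table
}

func main() {
	inputs := makeInputs(1, 16)
	table := buildTable(inputs)

	// MBA-obfuscated form of x*y
	obf := func(x, y uint64) uint64 { return (x&y)*(x|y) + (x&^y)*(^x&y) }
	if text, ok := table[signature(obf, inputs)]; ok {
		fmt.Println("lookup hit:", text)
	}
}
```

A hit only shows that the two functions agree on the sampled inputs. A real tool should still confirm the match, which is what qsynthesis's SMT equivalence check is for.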

A qsynthesis usage example:

import logging
import sys

from triton import ARCH

from qsynthesis import SimpleSymExec, TopDownSynthesizer, InputOutputOracleLevelDB
import qsynthesis
import pkg_resources
logging.basicConfig(level=logging.DEBUG)

qsynthesis.enable_logging()

RIP_ADDR = 0x40B160
RSP_ADDR = 0x800000

INSTRUCTIONS = [b'U', b'H\x89\xe5', b'H\x89}\xf8', b'H\x89u\xf0', b'H\x89U\xe8', b'H\x89M\xe0', b'L\x89E\xd8',
b'H\x8bE\xf0', b'H#E\xe0', b'H\x89\xc2', b'H\x8bE\xf0', b'H\x0bE\xe0', b'H\x0f\xaf\xd0', b'H\x8bE\xe0',
b'H\xf7\xd0', b'H#E\xf0', b'H\x89\xc1', b'H\x8bE\xf0', b'H\xf7\xd0', b'H#E\xe0', b'H\x0f\xaf\xc1',
b'H\x01\xc2', b'H\x8bE\xe0', b'H\x0f\xaf\xc0', b'H\x89\xd6', b'H!\xc6', b'H\x8bE\xf0', b'H#E\xe0',
b'H\x89\xc2', b'H\x8bE\xf0', b'H\x0bE\xe0', b'H\x0f\xaf\xd0', b'H\x8bE\xe0', b'H\xf7\xd0', b'H#E\xf0',
b'H\x89\xc1', b'H\x8bE\xf0', b'H\xf7\xd0', b'H#E\xe0', b'H\x0f\xaf\xc1', b'H\x01\xc2', b'H\x8bE\xe0',
b'H\x0f\xaf\xc0', b'H\t\xd0', b'H)\xc6', b'H\x89\xf0', b'H\x83\xe8\x01', b'H3E\xf0', b'H\x89\xc2',
b'H\x8bE\xf0', b'H#E\xe0', b'H\x89\xc1', b'H\x8bE\xf0', b'H\x0bE\xe0', b'H\x0f\xaf\xc8', b'H\x8bE\xe0',
b'H\xf7\xd0', b'H#E\xf0', b'H\x89\xc6', b'H\x8bE\xf0', b'H\xf7\xd0', b'H#E\xe0', b'H\x0f\xaf\xc6',
b'H\x01\xc1', b'H\x8bE\xe0', b'H\x0f\xaf\xc0', b'H1\xc8', b'H#E\xf0', b'H\x01\xc0', b'H)\xc2',
b'H\x89\xd0', b']', b'\xc3']

qsynthesis_version = pkg_resources.get_distribution("qsynthesis").version # read the installed qsynthesis version
print(f"The version of qsynthesis is: {qsynthesis_version}")

def test(oracle_file):
    # Perform symbolic execution of the instructions
    symexec = SimpleSymExec(ARCH.X86_64)  # initialize with the target architecture
    symexec.initialize_register('rip', RIP_ADDR)  # initialize registers
    symexec.initialize_register('rsp', RSP_ADDR)
    for opcode in INSTRUCTIONS:
        symexec.execute(opcode)  # execute each opcode
    rax = symexec.get_register_ast("rax")  # fetch the rax AST after execution

    # Load lookup tables
    ltm = InputOutputOracleLevelDB.load(oracle_file)  # load the lookup-table database

    # Perform Synthesis of the expression
    synthesizer = TopDownSynthesizer(ltm)  # synthesizer backed by the table
    synt_rax, simp = synthesizer.synthesize(rax)  # synthesize the rax expression

    # Print synthesis results
    print(f"simplified: {simp}")
    print(f"synthesized expression: {synt_rax.pp_str}")
    sz, nsz = rax.node_count, synt_rax.node_count
    print(f"size: {rax.node_count} -> {synt_rax.node_count}\nsize reduction:{((sz-nsz)*100)/sz:.2f}%")
    return symexec, rax, synt_rax

if __name__ == "__main__":
    sx, rax, srax = test("my_oracle_table/")
  • Output:
The version of qsynthesis is: 0.2.0
DEBUG:qsynthesis:try synthesis lookup: (((((((((((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi))))) & ((rcx * rcx))) - (((rcx * rcx)) | (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))))) - 0x1)) ^ rsi) - ((((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi) + ((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi))))) [2]
DEBUG:qsynthesis:try synthesis lookup: (((((((((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi))))) & ((rcx * rcx))) - (((rcx * rcx)) | (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))))) - 0x1)) ^ rsi) [2]
DEBUG:qsynthesis:[base] candidate expr accepted: current:47 candidate:10 (best:0) => ((((rcx * rsi)) ^ (~(rsi))) ^ ((rcx * rcx)))
DEBUG:qsynthesis:Replace: (((((((((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi))))) & ((rcx * rcx))) - (((rcx * rcx)) | (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))))) - 0x1)) ^ rsi) ===> ((((rcx * rsi)) ^ (~(rsi))) ^ ((rcx * rcx)))
DEBUG:qsynthesis:try synthesis lookup: ((((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi) + ((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi))) [2]
DEBUG:qsynthesis:try synthesis lookup: ((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi) [2]
DEBUG:qsynthesis:[base] candidate expr accepted: current:23 candidate:11 (best:0) => ((((rsi * rcx)) & rsi) ^ (((rcx * rcx)) & rsi))
DEBUG:qsynthesis:Replace: ((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi) ===> ((((rsi * rcx)) & rsi) ^ (((rcx * rcx)) & rsi))
DEBUG:qsynthesis:try synthesis lookup: ((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi) [2]
DEBUG:qsynthesis:expression cache found !
DEBUG:qsynthesis:Replace: ((((rcx * rcx)) ^ (((((rsi & rcx) * (rsi | rcx))) + ((((~(rsi)) & rcx) * ((~(rcx)) & rsi)))))) & rsi) ===> ((((rsi * rcx)) & rsi) ^ (((rcx * rcx)) & rsi))
simplified: True
synthesized expression: ((((((rcx * rsi)) ^ (~(rsi))) ^ ((rcx * rcx))) - ((((((rsi * rcx)) & rsi) ^ (((rcx * rcx)) & rsi)) + ((((rsi * rcx)) & rsi) ^ (((rcx * rcx)) & rsi))))))
size: 95 -> 34
size reduction:64.21%
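The synthesized result can be sanity-checked outside Triton. The key simplification is the mixed boolean-arithmetic identity x*y = (x&y)*(x|y) + (x&~y)*(~x&y), which collapses the repeated sub-term to rsi*rcx. The Go sketch below transcribes both expressions from the log, factored through shared sub-terms for readability, and compares them on random 64-bit inputs. Random testing is evidence, not proof; the SMT check exists for proof:

```go
package main

import (
	"fmt"
	"math/rand"
)

// original is the obfuscated rax expression from the first DEBUG line,
// factored through its two repeated sub-terms m and s.
func original(rsi, rcx uint64) uint64 {
	m := (rsi&rcx)*(rsi|rcx) + (^rsi&rcx)*(^rcx&rsi)
	s := rcx * rcx
	a := (((m & s) - (s | m)) - 0x1) ^ rsi
	b := ((s ^ m) & rsi) + ((s ^ m) & rsi)
	return a - b
}

// synthesized is the expression printed after "synthesized expression:";
// t names the sub-term that appears twice.
func synthesized(rsi, rcx uint64) uint64 {
	t := ((rsi * rcx) & rsi) ^ ((rcx * rcx) & rsi)
	return (((rcx * rsi) ^ (^rsi)) ^ (rcx * rcx)) - (t + t)
}

// equivalentOnSamples compares both on n random inputs. uint64 arithmetic
// wraps modulo 2^64, matching 64-bit bitvector semantics.
func equivalentOnSamples(n int, seed int64) bool {
	rng := rand.New(rand.NewSource(seed))
	for i := 0; i < n; i++ {
		rsi, rcx := rng.Uint64(), rng.Uint64()
		if original(rsi, rcx) != synthesized(rsi, rcx) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println("equal on 100000 random inputs:", equivalentOnSamples(100000, 42))
}
```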

Syntax analysis

Completing the switch statement:

stm: ......
| 'switch' '(' expr ')' '{' case* default? '}'
......

case : 'case' expr ':' (stm ';'?)*;
default: 'default' ':' (stm ';'?)*;

The core code:

if str_1 == "switch" {
matched := false // set once a case matches; later cases then fall through
value, _ := h.handExpr(stm.GetChild(2).(*Jsp.ExprContext))
// children: 'switch' '(' expr ')' '{' case* default '}'; the -7 assumes a default branch is present
for i := 0; i < stm.GetChildCount()-7; i++ {
cs := stm.GetChild(i + 5)
num, _ := h.handExpr(cs.GetChild(1).(*Jsp.ExprContext))
if num.getString() == value.getString() || matched {
matched = true
for y := 0; y < cs.GetChildCount()-2 && !h.isBreak; y++ {
if cs.GetChild(y+2).GetChild(0) != nil {
h.handStm(cs.GetChild(y + 2).(*Jsp.StmContext))
}
}
if h.isBreak { // a break ends the whole switch
break
}
}
}
if !matched {
def := stm.GetChild(stm.GetChildCount() - 2)
for y := 0; y < def.GetChildCount() && !h.isBreak; y++ {
if def.GetChild(y).GetChild(0) != nil {
h.handStm(def.GetChild(y).(*Jsp.StmContext))
}
}
}
h.isBreak = false // the break belonged to this switch; do not leak it into an enclosing loop
}

For switch to work, break and continue had to be implemented:

if str_1 == "break" {
h.isBreak = true
return nil, 0
}
if str_1 == "continue" {
h.isContinue = true
return nil, 0
}
  • The for and while statements were also updated to handle break and continue:
if str_1 == "while" {
for {
if h.isBreak {
h.isBreak = false
break
}
if h.isContinue {
h.isContinue = false
continue
}
vaule, typ := h.handExpr(stm.GetChild(2).(*Jsp.ExprContext))
if ok := h.checkExpr(vaule, typ); !ok {
break
}
h.handBlock(stm.GetChild(4).(*Jsp.BlockContext))
}
}
if str_1 == "for" {
var smts [3]*Jsp.StmContext
var index int = 0
for _, c := range stm.GetChildren() {
_, ok := c.(*Jsp.StmContext)
if tmp := fmt.Sprintf("%v", c); tmp == ";" {
index++
}
if c.GetChild(0) != nil && ok {
smts[index] = c.(*Jsp.StmContext)
}
}
if smts[0] != nil {
h.handStm(smts[0])
}
for {
if smts[1] != nil {
value, typ := h.handStm(smts[1])
if ok := h.checkExpr(value, typ); !ok {
break
}
}
h.handBlock(stm.GetChild(stm.GetChildCount() - 1).(*Jsp.BlockContext))
if h.isBreak { // break: leave the loop without running the update statement
h.isBreak = false
break
}
h.isContinue = false // continue: only the rest of the body is skipped, the update still runs
if smts[2] != nil {
h.handStm(smts[2])
}
}
}

Implementing Class

A class needs to record its name, its variables and its methods:

type Class struct {
name string
vars []string
funcs map[string]*Fuction
}

Add Class to SymTable:

type SymTable struct {
Fuctions []Fuction
Classes []*Class
InFuctions inFunction
scope *Scope
}
  • The accompanying helper functions:
func NewClass(name string) *Class {
return &Class{name: name,
vars: []string{},
funcs: map[string]*Fuction{},
}
}

func (vl *SymTable) getClass(name string) (*Class, bool) {
for _, c := range vl.Classes {
if name == c.name {
return c, true
}
}
return nil, false
}

func (vl *SymTable) addClass(class *Class) {
vl.Classes = append(vl.Classes, class)
}

func (c *Class) addVar(name string) {
c.vars = append(c.vars, name)
}

func (c *Class) initFunc(m *Jsp.ConsContext) { /* handle the constructor */
currentT := current
defer func() {
current = currentT
}()

name := fmt.Sprintf("%v", m.GetChild(0))
current = name

params := m.GetChild(2)
var args []string
for _, param := range params.GetChildren() {
if param.GetChild(0) != nil {
arg := fmt.Sprintf("%v", param.GetChild(0))
args = append(args, arg)
}
}
method := NewFunction(name, args, m.GetChild(4).(*Jsp.BlockContext))
c.funcs[name] = method
}

func (c *Class) addFunc(m *Jsp.MethodContext) { /* handle a method */
currentT := current
defer func() {
current = currentT
}()

name := fmt.Sprintf("%v", m.GetChild(0))
current = name

params := m.GetChild(2)
var args []string
for _, param := range params.GetChildren() {
if param.GetChild(0) != nil {
arg := fmt.Sprintf("%v", param.GetChild(0))
args = append(args, arg)
}
}
method := NewFunction(name, args, m.GetChild(4).(*Jsp.BlockContext))
c.funcs[name] = method
}

func (c *Class) getFunc(name string) (*Fuction, bool) {
if _, ok := c.funcs[name]; ok {
return c.funcs[name], true
}
return nil, false
}

Add a class object type alongside the basic object types:

type CobjObject struct {
BaseObject
ClassName string
Value map[string]BaseObject
}
  • The accompanying helper functions:
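func (o *CobjObject) getString() string {
if len(o.Value) == 0 { // slicing off the last char of "{" would also drop the brace
return "{}"
}
parts := make([]string, 0, len(o.Value))
for k, v := range o.Value {
parts = append(parts, fmt.Sprintf("%s:%s", k, v.getString()))
}
return "{ " + strings.Join(parts, ", ") + " }"
}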

func (o *CobjObject) addValue(name string, obj BaseObject) {
o.Value[name] = obj
}

func (o *CobjObject) setValue(name string, obj BaseObject) {
for i, _ := range o.Value {
if i == name {
o.Value[i] = obj
}
}
}

Grammar rules for class:

class: 'class' IDENTIFIER '{' cons? method* '}';
cons: 'constructor' '(' paramlist? ')' block;
method: IDENTIFIER '(' paramlist? ')' block;

Handle class in the listener:

func (l *MyListener) EnterClass(ctx *Jsp.ClassContext) {
name := fmt.Sprintf("%v", ctx.GetChild(1))
class := NewClass(name)

for _, c := range ctx.GetChildren() {
if con, ok := c.(*Jsp.ConsContext); ok {
class.initFunc(con)
}
if met, ok := c.(*Jsp.MethodContext); ok {
class.addFunc(met)
}
}
vl.addClass(class)
}

Handle this and class objects in handExpr:

if _, ok := exp.GetChild(0).(*Jsp.ThisContext); ok { // this
return &ThisObjetc{}, Jsp.JavaScriptParserRULE_this
}

if cobj, ok := exp.GetChild(0).(*Jsp.CobjContext); ok { // class object
return h.handExpr_cobj(cobj) /* core function */
}

The expr '.' expr case in handExpr gets matching branches as well:

} else if obj, ok := exp1.(*CobjObject); ok { // class object
if obj.Value[name2] != nil {
obj.setValue(name1, obj.Value[name2])
str := fmt.Sprintf("%v", obj.Value[name2].getString())
if str[0] == '"' {
return obj.Value[name2], Jsp.JavaScriptLexerSTRING
}
if renum.MatchString(str) {
return obj.Value[name2], Jsp.JavaScriptLexerNUMBER
}
return obj.Value[name2], -1
} else {
var args []Variant
cname := obj.ClassName

if call, ok := exp.GetChild(2).GetChild(0).(*Jsp.FuncallContext); ok {
for _, c := range call.GetChild(2).GetChildren() {
if c.GetChild(0) != nil {
value, typ := h.handExpr(c.(*Jsp.ExprContext))
args = append(args, *NewVariant("unknown", value, typ))
}
}
}
if class, ok := vl.getClass(cname); ok {
value, typ := h.handMethod(obj, class, name2, args)
return value, typ
}
return nil, -1
}
} else if this, ok := exp1.(*ThisObjetc); ok { // this
this.setValue(name2)
this.setObject(exp2)
return this, Jsp.JavaScriptParserRULE_this

The core function handExpr_cobj:

func (h *Handler) handExpr_cobj(new *Jsp.CobjContext) (BaseObject, int) {
var args []Variant
cname := fmt.Sprintf("%s", new.GetChild(1))
cobjt := &CobjObject{Value: make(map[string]BaseObject), ClassName: cname}

if class, ok := vl.getClass(cname); ok {
for _, o := range new.GetChild(3).GetChildren() {
if o.GetChild(0) != nil {
value, typ := h.handExpr(o.(*Jsp.ExprContext))
args = append(args, Variant{"unknown", value, typ})
}
}
h.handMethod(cobjt, class, "constructor", args)
}

return cobjt, Jsp.JavaScriptParserRULE_cobj
}
handExpr_cobj runs the constructor through handMethod:
func (h *Handler) handMethod(cobjt *CobjObject, c *Class, name string, args []Variant) (BaseObject, int) {
currentT := current
defer func() {
vl.delVarAll()
current = currentT
}()
current = name
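fun, ok := c.getFunc(name)
if !ok { // fun would be nil here, and fun.args below would panic
if name == "constructor" {
return nil, -1 // no explicit constructor: nothing to run, like JavaScript's implicit one
}
panic(fmt.Sprintf("method %s undefined", name))
}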

if len(fun.args) != len(args) {
panic("The parameters do not match")
} else {
for i, param := range fun.args {
value, typ := args[i].value, args[i].typ
vl.addVar(param, value, typ)
}
}

block := fun.fctx
for _, c := range block.GetChildren() {
if h.isBreak {
break
}
if h.isContinue {
continue
}
if c.GetChild(0) != nil {
value, typ := h.handStm(c.(*Jsp.StmContext))
if this, ok := value.(*ThisObjetc); ok {
cobjt.addValue(this.Value, this.Object)
}
if h.isReturn {
return value, typ
}
}
}

return nil, -1
}
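Below is a stripped-down Go model of the same class/instance split. Fields live on the instance (like CobjObject.Value), methods live on the class, and this is passed explicitly. All names are hypothetical, and int values stand in for BaseObject to keep it short. A missing constructor is treated as a no-op, as in JavaScript, and an unknown method is reported as an error rather than dereferenced:

```go
package main

import "fmt"

// object is an instance: its class name plus its own fields.
type object struct {
	class  string
	fields map[string]int
}

// method receives the instance as this plus the call arguments.
type method func(this *object, args ...int) int

// class maps method names to bodies; "constructor" is looked up like any other method.
type class struct {
	name    string
	methods map[string]method
}

// newObject creates empty fields, then runs the constructor if the class has one.
func newObject(c *class, args ...int) *object {
	obj := &object{class: c.name, fields: map[string]int{}}
	if ctor, ok := c.methods["constructor"]; ok {
		ctor(obj, args...)
	}
	return obj
}

// call looks the method up on the class and binds the instance as this.
func call(c *class, obj *object, name string, args ...int) (int, error) {
	m, ok := c.methods[name]
	if !ok {
		return 0, fmt.Errorf("%s.%s is not a method", c.name, name)
	}
	return m(obj, args...), nil
}

func main() {
	counter := &class{name: "Counter", methods: map[string]method{
		"constructor": func(this *object, args ...int) int { this.fields["n"] = args[0]; return 0 },
		"add":         func(this *object, args ...int) int { this.fields["n"] += args[0]; return this.fields["n"] },
	}}
	c := newObject(counter, 5)
	v, _ := call(counter, c, "add", 3)
	fmt.Println(v) // 8
	_, err := call(counter, c, "missing")
	fmt.Println(err) // Counter.missing is not a method
}
```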

Refactoring the basic types

Previously, every value was stored as a string:

type Variant struct {
name string
value string
typ int
}

In JavaScript everything is an object, so each basic type now gets its own struct:

type BaseObject interface {
getString() string
}

type UndefObject struct {
BaseObject
Value string
}

type StringObject struct {
BaseObject
Value string
}

type NumberObject struct {
BaseObject
Value int
}

type BooleanObject struct {
BaseObject
Value bool
}

type ArrayObject struct {
BaseObject
Value []BaseObject
}

type ObjObject struct {
BaseObject
Value map[string]BaseObject
}

type VariantObject struct {
BaseObject
Value string
}
func (n *UndefObject) getString() string {
return "undefined"
}

func (s *StringObject) getString() string {
return s.Value
}

func (n *NumberObject) getString() string {
return fmt.Sprintf("%d", n.Value)
}

func (b *BooleanObject) getString() string {
if b.Value {
return "true"
}
return "false"
}

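func (a *ArrayObject) getString() string {
if len(a.Value) == 0 { // slicing off the last char of "[" would also drop the bracket
return "[]"
}
parts := make([]string, 0, len(a.Value))
for _, n := range a.Value {
parts = append(parts, n.getString())
}
return "[ " + strings.Join(parts, ", ") + " ]"
}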

func (a *ArrayObject) setValue(num BaseObject) {
a.Value = append(a.Value, num)
}

func (a *ArrayObject) getValue(index int) BaseObject {
return a.Value[index]
}

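func (o *ObjObject) getString() string {
if len(o.Value) == 0 { // slicing off the last char of "{" would also drop the brace
return "{}"
}
parts := make([]string, 0, len(o.Value))
for k, v := range o.Value {
parts = append(parts, fmt.Sprintf("%s:%s", k, v.getString()))
}
return "{ " + strings.Join(parts, ", ") + " }"
}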

func (o *ObjObject) addValue(name string, obj BaseObject) {
o.Value[name] = obj
}

func (o *ObjObject) setValue(name string, obj BaseObject) {
for i, _ := range o.Value {
if i == name {
o.Value[i] = obj
}
}
}

func (v *VariantObject) getString() string {
return v.Value
}
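Every value now carries its Go type, so operators can dispatch on the dynamic type with type assertions. This avoids round-tripping through getString() and strconv.Atoi. Below is a self-contained sketch of + for two of the types. The types are simplified redefinitions, not the interpreter's actual code, and here strings are stored without their surrounding quotes:

```go
package main

import (
	"fmt"
	"strconv"
)

// Simplified stand-ins for the interpreter's object types.
type BaseObject interface{ getString() string }

type NumberObject struct{ Value int }
type StringObject struct{ Value string } // stored without surrounding quotes here

func (n *NumberObject) getString() string { return strconv.Itoa(n.Value) }
func (s *StringObject) getString() string { return s.Value }

// add implements + by checking the operands' dynamic types: number+number adds,
// and if either side is a string, both sides are concatenated as text.
func add(a, b BaseObject) (BaseObject, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("undefined operand for +")
	}
	if x, ok := a.(*NumberObject); ok {
		if y, ok := b.(*NumberObject); ok {
			return &NumberObject{Value: x.Value + y.Value}, nil
		}
	}
	_, aIsStr := a.(*StringObject)
	_, bIsStr := b.(*StringObject)
	if aIsStr || bIsStr {
		return &StringObject{Value: a.getString() + b.getString()}, nil
	}
	return nil, fmt.Errorf("unsupported operand types for +: %T and %T", a, b)
}

func main() {
	sum, _ := add(&NumberObject{Value: 2}, &NumberObject{Value: 40})
	fmt.Println(sum.getString()) // 42
	text, _ := add(&StringObject{Value: "n="}, &NumberObject{Value: 1})
	fmt.Println(text.getString()) // n=1
}
```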

Once the structs that hold basic values change, a lot of code has to be refactored along with them

First, binary operators in handExpr:

op := fmt.Sprintf("%v", exp.GetChild(1))
exp1, typ1 := handExpr(exp.GetChild(0).(*Jsp.ExprContext))
exp2, typ2 := handExpr(exp.GetChild(2).(*Jsp.ExprContext))

......

if op == "+" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1.getString())
num2, _ := strconv.Atoi(exp2.getString())
return &NumberObject{Value: num1 + num2}, Jsp.JavaScriptLexerNUMBER
} else if typ1 == Jsp.JavaScriptLexerSTRING || typ2 == Jsp.JavaScriptLexerSTRING {
return &StringObject{Value: exp1.getString() + exp2.getString()}, Jsp.JavaScriptLexerSTRING
} else {
return nil, -1
}
}

......

if op == ">" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1.getString())
num2, _ := strconv.Atoi(exp2.getString())
return &BooleanObject{Value: num1 > num2}, Jsp.JavaScriptLexerNUMBER
} else {
return nil, -1
}
}

......

Unary operators in handExpr:

if op == "++" {
if typ1 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1.getString())
num := &NumberObject{Value: num1 + 1}
name := fmt.Sprintf("%v", exp.GetChild(0).GetChild(0).GetChild(0))
vl.setVar(name, num, Jsp.JavaScriptLexerNUMBER)
return num, Jsp.JavaScriptLexerNUMBER
} else {
return nil, -1
}
}

Primitive (literal) values in handExpr:

if exp.GetChildCount() == 1 { // funcall | num | id | str | arr | obj
if call, ok := exp.GetChild(0).(*Jsp.FuncallContext); ok {
return handExpr_funcall(call)
} else if arr, ok := exp.GetChild(0).(*Jsp.ArrContext); ok {
arrt := &ArrayObject{Value: make([]BaseObject, 0)}
for _, a := range arr.GetChildren() {
if a.GetChild(0) != nil {
num, _ := strconv.Atoi(fmt.Sprintf("%v", a.GetChild(0).GetChild(0)))
arrt.setValue(&NumberObject{Value: num})
}
}
return arrt, Jsp.JavaScriptParserRULE_arr
} else if obj, ok := exp.GetChild(0).(*Jsp.ObjContext); ok { // object
objt := &ObjObject{Value: make(map[string]BaseObject)}
for _, o := range obj.GetChildren() {
if o.GetChild(0) != nil {
name := fmt.Sprintf("%v", o.GetChild(0).GetChild(0))
value, _ := handExpr(o.GetChild(2).(*Jsp.ExprContext))
objt.addValue(name, value)
}
}
return objt, Jsp.JavaScriptParserRULE_obj
} else {
renum := regexp.MustCompile(`^\d+$`)
resym := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
str := fmt.Sprintf("%v", exp.GetChild(0).GetChild(0))
if str[0] == '"' { // string
return &StringObject{Value: str}, Jsp.JavaScriptLexerSTRING
}
if renum.MatchString(str) { // number
num, _ := strconv.Atoi(str)
return &NumberObject{Value: num}, Jsp.JavaScriptLexerNUMBER
}
if resym.MatchString(str) { // var
value, typ := vl.getVarByName(str)
return value, typ
}
}
}

Finally, arrays and objects:

if op == "." { // expr '.' expr
name1 := fmt.Sprintf("%v", exp.GetChild(0).GetChild(0).GetChild(0))
name2 := fmt.Sprintf("%v", exp.GetChild(2).GetChild(0).GetChild(0))
if obj, ok := exp1.(*ObjObject); ok {
obj.setValue(name1, obj.Value[name2])
return obj.Value[name2], typ
} else {
panic("Object syntax error")
}
}
if exp.GetChildCount() == 4 { // expr '[' expr ']'
index, _ := handExpr(exp.GetChild(2).(*Jsp.ExprContext))
exp, typ := handExpr(exp.GetChild(0).(*Jsp.ExprContext))

if arr, ok := exp.(*ArrayObject); ok {
i, _ := strconv.Atoi(index.getString())
return arr.getValue(i), typ
} else {
panic("Array syntax error")
}
}

Refactoring scopes

Previously, scopes were only handled per function:

func handFuncdef(name string, args []Variant) (BaseObject, int) {
var currentT = current
current = name
params, ctx := vl.getFunc(name)

defer func() {
vl.delVarAll() /* clear the function's parameters from SymTable */
current = currentT
}()

if ctx == nil {
value, typ := vl.callInFunc(name, args)
return value, typ
}

if len(params) != len(args) {
panic("The parameters do not match")
} else {
for i, param := range params {
value, typ := args[i].value, args[i].typ
vl.addVar(param, value, typ) /* add the function parameters to SymTable */
}
}
block := ctx.GetChild(ctx.GetChildCount() - 1).(*Jsp.BlockContext)
revalue, retyp := handBlock(block)

return revalue, retyp
}
func (vl *SymTable) addVar(name string, value BaseObject, typ int) {
vl.Variants = append(vl.Variants, *NewVariant(current+"."+name, value, typ))
vl.setFuncByName(current, name)
}
func (vl *SymTable) delVarAll() {
for _, v := range vl.Variants {
substrs := strings.Split(v.name, ".")
if substrs[0] == current {
vl.Variants = vl.Variants[:len(vl.Variants)-1]
}
}
}
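  • This approach has a flaw. All calls of a recursive function share the same current prefix, so when an inner call returns, delVarAll also clears parameters that its caller still needs, and the program breaks internally. In addition, the loop truncates the tail of the slice instead of removing the matching entry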

The refactored scope handling is built on a Scope struct:

type Scope struct {
Outer *Scope
Objects map[string]*Variant
}

func NewScope(outer *Scope) *Scope {
return &Scope{
Outer: outer,
Objects: make(map[string]*Variant),
}
}

func (s *Scope) HasName(name string) bool {
_, ok := s.Objects[name]
return ok
}

func (s *Scope) Lookup(name string) (*Scope, *Variant) {
for ; s != nil; s = s.Outer {
if obj := s.Objects[name]; obj != nil {
return s, obj
}
}
return nil, nil
}

func (s *Scope) Insert(v *Variant) (alt *Variant) {
if alt = s.Objects[v.name]; alt == nil {
s.Objects[v.name] = v
}
return
}

Each nesting level is represented by one Scope, and two functions move between levels:

func (p *SymTable) enterScope() { /* create and enter a new inner scope */
p.scope = NewScope(p.scope)
}

func (p *SymTable) leaveScope(scope *Scope) { /* return to the outer scope */
p.scope = scope
}

With Scope in place, the other SymTable functions can be rewritten:

func (vl *SymTable) addVar(name string, value BaseObject, typ int) {
vl.scope.Insert(NewVariant(current+"."+name, value, typ))
vl.setFuncByName(current, name)
}

func (vl *SymTable) setVar(name string, value BaseObject, typ int) {
substrs := strings.Split(name, "@")
if _, obj := vl.scope.Lookup(current + "." + substrs[0]); obj != nil {
if typ == Jsp.JavaScriptParserRULE_obj {
if o, ok := obj.value.(*ObjObject); ok {
o.setValue(substrs[1], value)
}
} else {
obj.value = value
obj.typ = typ
}
} else {
panic(fmt.Sprintf("var %s undefined", name))
}
}

func (vl *SymTable) getVarByName(name string) (BaseObject, int) {
if _, obj := vl.scope.Lookup(current + "." + name); obj != nil {
return obj.value, obj.typ
}
return nil, -1
}

Finally, handBlock enters a new scope on entry and restores the previous one on exit:

func (h *Handler) handBlock(blo *Jsp.BlockContext) (BaseObject, int) {
defer vl.leaveScope(vl.scope)
vl.enterScope()

for _, c := range blo.GetChildren() {
if h.isBreak {
break
}
if h.isContinue {
continue
}
if c.GetChild(0) != nil {
value, typ := h.handStm(c.(*Jsp.StmContext))
if h.isReturn {
return value, typ
}
}
}

return nil, -1
}
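The scope chain is what makes recursion safe. Every call gets its own Scope for its parameters. A deferred call's arguments are evaluated when the defer statement runs, so "defer vl.leaveScope(vl.scope)" captures the caller's scope before enterScope and always restores it. The self-contained Go sketch below uses the same Scope/Lookup/Insert shape. Variant is reduced to an int value, and fact is a hypothetical recursive function written only for this demo:

```go
package main

import "fmt"

type Variant struct {
	name  string
	value int
}

type Scope struct {
	Outer   *Scope
	Objects map[string]*Variant
}

func NewScope(outer *Scope) *Scope {
	return &Scope{Outer: outer, Objects: make(map[string]*Variant)}
}

// Lookup walks outward until some enclosing scope defines name.
func (s *Scope) Lookup(name string) (*Scope, *Variant) {
	for ; s != nil; s = s.Outer {
		if obj := s.Objects[name]; obj != nil {
			return s, obj
		}
	}
	return nil, nil
}

// Insert adds v unless the name is already declared in this very scope.
func (s *Scope) Insert(v *Variant) (alt *Variant) {
	if alt = s.Objects[v.name]; alt == nil {
		s.Objects[v.name] = v
	}
	return
}

type SymTable struct{ scope *Scope }

func (p *SymTable) enterScope()             { p.scope = NewScope(p.scope) }
func (p *SymTable) leaveScope(scope *Scope) { p.scope = scope }

// fact computes n! recursively, declaring its parameter in a fresh scope per call.
func (p *SymTable) fact(n int) int {
	defer p.leaveScope(p.scope) // argument evaluated now: the caller's scope
	p.enterScope()
	p.scope.Insert(&Variant{name: "n", value: n})
	if n <= 1 {
		return 1
	}
	r := p.fact(n - 1)
	_, v := p.scope.Lookup("n") // still this call's n, not the callee's
	return v.value * r
}

func main() {
	st := &SymTable{scope: NewScope(nil)}
	st.scope.Insert(&Variant{name: "n", value: 100}) // a global n
	fmt.Println(st.fact(5))                           // 120
	_, g := st.scope.Lookup("n")
	fmt.Println(g.value) // 100: the global n is visible again
}
```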

Syntax analysis

Completing the if, while and for statements:

| 'if' '(' expr ')' block ('else' block)?
| 'while' '(' expr ')' block
| 'for' '(' stm? ';' stm? ';' stm? ')' block

The corresponding code:

if str_1 == "return" {
value, typ = handExpr(stm.GetChild(1).(*Jsp.ExprContext))
return value, typ, true
}
if str_1 == "var" {
name := fmt.Sprintf("%v", stm.GetChild(1))
str_2 = fmt.Sprintf("%v", stm.GetChild(2))
if str_2 == "=" {
vaule, typ := handExpr(stm.GetChild(3).(*Jsp.ExprContext))
vl.addVar(name, vaule, typ)
if typ == Jsp.JavaScriptParserRULE_obj {
vl.setVarAll(name)
vl.showVarAll()
}
}
vl.showVarAll()
}
if str_1 == "if" {
vaule, typ := handExpr(stm.GetChild(2).(*Jsp.ExprContext))
if ok := checkExpr(vaule, typ); ok {
handBlock(stm.GetChild(4).(*Jsp.BlockContext))
} else if stm.GetChildCount() == 7 { // 'if' '(' expr ')' block 'else' block: the else block is child 6
handBlock(stm.GetChild(6).(*Jsp.BlockContext))
}
}
if str_1 == "while" {
for {
vaule, typ := handExpr(stm.GetChild(2).(*Jsp.ExprContext))
if ok := checkExpr(vaule, typ); !ok {
break
}
handBlock(stm.GetChild(4).(*Jsp.BlockContext))
}
}
if str_1 == "for" {
var smts [3]*Jsp.StmContext
var index int = 0

for _, c := range stm.GetChildren() {
_, ok := c.(*Jsp.StmContext)
if tmp := fmt.Sprintf("%v", c); tmp == ";" {
index++
}
if c.GetChild(0) != nil && ok {
smts[index] = c.(*Jsp.StmContext)
}
}
if smts[0] != nil {
handStm(smts[0])
}

for {
if smts[1] != nil {
vaule, typ, _ := handStm(smts[1])
if ok := checkExpr(vaule, typ); !ok {
break
}
}
handBlock(stm.GetChild(stm.GetChildCount() - 1).(*Jsp.BlockContext))
if smts[2] != nil {
handStm(smts[2])
}
}
}

To support if statements, expression handling needs the comparison operators:

if op == ">" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.FormatBool(num1 > num2), Jsp.JavaScriptParserRULE_bool
} else {
return "", -1
}
}
if op == ">=" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.FormatBool(num1 >= num2), Jsp.JavaScriptParserRULE_bool
} else {
return "", -1
}
}
if op == "<" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.FormatBool(num1 < num2), Jsp.JavaScriptParserRULE_bool
} else {
return "", -1
}
}
if op == "<=" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.FormatBool(num1 <= num2), Jsp.JavaScriptParserRULE_bool
} else {
return "", -1
}
}
if op == "==" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.FormatBool(num1 == num2), Jsp.JavaScriptParserRULE_bool
} else {
return "", -1
}
}
if op == "!=" {
if typ1 == Jsp.JavaScriptLexerNUMBER && typ2 == Jsp.JavaScriptLexerNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.FormatBool(num1 != num2), Jsp.JavaScriptParserRULE_bool
} else {
return "", -1
}
}
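The six branches differ only in the operator. Below is a table-driven Go sketch that keeps the same convention: string operands in, a FormatBool result out. compareOps and compare are hypothetical names, and the NUMBER type check is left to the caller:

```go
package main

import (
	"fmt"
	"strconv"
)

// compareOps maps each comparison operator to its integer semantics.
var compareOps = map[string]func(a, b int) bool{
	">":  func(a, b int) bool { return a > b },
	">=": func(a, b int) bool { return a >= b },
	"<":  func(a, b int) bool { return a < b },
	"<=": func(a, b int) bool { return a <= b },
	"==": func(a, b int) bool { return a == b },
	"!=": func(a, b int) bool { return a != b },
}

// compare evaluates "exp1 op exp2" for numeric operands held as strings.
// ok is false for an unknown operator or a non-numeric operand.
func compare(op, exp1, exp2 string) (result string, ok bool) {
	f, found := compareOps[op]
	if !found {
		return "", false
	}
	a, err1 := strconv.Atoi(exp1)
	b, err2 := strconv.Atoi(exp2)
	if err1 != nil || err2 != nil {
		return "", false
	}
	return strconv.FormatBool(f(a, b)), true
}

func main() {
	fmt.Println(compare(">=", "3", "3")) // true true
	fmt.Println(compare("<", "5", "2"))  // false true
}
```

Adding an operator then means adding one map entry instead of another nine-line branch.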

Block handling was added as well:

func handBlock(blo *Jsp.BlockContext) (string, int) {
for _, c := range blo.GetChildren() {
if c.GetChild(0) != nil {
value, typ, ok := handStm(c.(*Jsp.StmContext))
if ok {
return value, typ
}
}
}
return "", -1
}

To support objects, I changed quite a few places:

if exp.GetChildCount() == 1 {
if call, ok := exp.GetChild(0).(*Jsp.FuncallContext); ok {
return handExpr_funcall(call)
} else if obj, ok := exp.GetChild(0).(*Jsp.ObjContext); ok {
for _, o := range obj.GetChildren() {
if o.GetChild(0) != nil {
name := fmt.Sprintf("%v", o.GetChild(0).GetChild(0))
value, typ := handExpr(o.GetChild(2).(*Jsp.ExprContext))
vl.addVar("*."+name, value, typ)
}
}
return "", Jsp.JavaScriptParserRULE_obj
} else {
renum := regexp.MustCompile(`^\d+$`)
resym := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
str := fmt.Sprintf("%v", exp.GetChild(0).GetChild(0))
if str[0] == '"' {
return str, Jsp.JavaScriptLexerSTRING
}
if renum.MatchString(str) {
return str, Jsp.JavaScriptLexerNUMBER
}
if resym.MatchString(str) {
value, typ := vl.getVarByName(str)
return value, typ
}
}
}
  • When a Jsp.ObjContext is matched, each entry of the object is loaded into the dynamic symbol table SymTable, giving this layout:
|----------SymTable--------------------------------------|
|name type value |
|--------------------------------------------------------|
|globol.*.type 20 "porsche" |
|globol.*.model 20 "911" |
|globol.*.color 20 "white" |
|--------------------------------------------------------|
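  • At this point the entries cannot be used directly. They are initialized again only when the object is first bound to a variable (here car), which gives: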
|----------SymTable--------------------------------------|
|name type value |
|--------------------------------------------------------|
|globol.car.type 20 "porsche" |
|globol.car.model 20 "911" |
|globol.car.color 20 "white" |
|globol.car 5 {car} |
|--------------------------------------------------------|

The corresponding code:

if str_2 == "=" {
vaule, typ := handExpr(stm.GetChild(3).(*Jsp.ExprContext))
vl.addVar(name, vaule, typ)
if typ == Jsp.JavaScriptParserRULE_obj {
vl.setVarAll(name)
vl.showVarAll()
}
}
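  • This guarantees that unreferenced objects are never used through SymTable, but their entries also stay in SymTable forever
  • I did consider purging unused object data from SymTable, but since a better approach may come up later, I have not written the cleanup code yet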
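The renaming that turns the first table into the second can be sketched with a flat map of dotted names. The post does not show setVarAll's body. bindObject below is therefore a hypothetical reconstruction of the behaviour the two tables describe, not the author's implementation. Values are kept as strings and the type codes are omitted:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// bindObject renames every pending "<scope>.*.<field>" entry to
// "<scope>.<name>.<field>" and records the object itself under "<scope>.<name>".
func bindObject(table map[string]string, scope, name string) {
	pending := scope + ".*."
	var keys []string // collect first, then rewrite, so the map is not mutated mid-range
	for k := range table {
		if strings.HasPrefix(k, pending) {
			keys = append(keys, k)
		}
	}
	for _, k := range keys {
		table[scope+"."+name+"."+strings.TrimPrefix(k, pending)] = table[k]
		delete(table, k)
	}
	table[scope+"."+name] = "{" + name + "}"
}

func main() {
	table := map[string]string{
		"globol.*.type":  `"porsche"`,
		"globol.*.model": `"911"`,
		"globol.*.color": `"white"`,
	}
	bindObject(table, "globol", "car")

	keys := make([]string, 0, len(table))
	for k := range table {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		fmt.Printf("%-18s %s\n", k, table[k])
	}
}
```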

Finally, expression handling gets code for object member access:

if op == "." {
name1 := exp1[1 : len(exp1)-1]
name2 := fmt.Sprintf("%v", exp.GetChild(2).GetChild(0).GetChild(0))
value, typ := vl.getVarByName(name1 + op + name2)
return value, typ
}
  • Here member access (.) is simply treated as an operator and forced into the binary-operator handling
  • The caller that receives this result:
if exp1, ok := stm.GetChild(0).(*Jsp.ExprContext); ok {
value, typ = handExpr(exp1)
if str_2 == "=" {
if exp2, ok := stm.GetChild(2).(*Jsp.ExprContext); ok {
var name string
value2, typ2 := handExpr(exp2)

if exp1.GetChild(0).GetChild(0).GetChild(0) == nil {
name = fmt.Sprintf("%v", exp1.GetChild(0).GetChild(0))
} else {
name = fmt.Sprintf("%v", exp1.GetChild(0).GetChild(0).GetChild(0))
name += "."
name += fmt.Sprintf("%v", exp1.GetChild(2).GetChild(0).GetChild(0))
}
vl.setVar(name, value2, typ2)
}
}
}
  • This still looks rather forced; I will consider improving it later

This is a simple JavaScript interpreter written in Go. Both lexical analysis and parsing are built with antlr4

Lexical analysis

The lexer is generated by antlr4 from the following grammar:

lexer grammar JavaScriptLexer;

FUNCTION : 'function';
RETURN : 'return';
VAR : 'var';

IF : 'if';
ELSE : 'else';
WHILE : 'while';
FOR : 'for';
BREAK : 'break';
CONTINUE : 'continue';

TRUE: 'true';
FALSE: 'false';

LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACKET: '[';
RBRACKET: ']';

NUMBER: [0-9]+('.'[0-9]+)?;
IDENTIFIER: [a-zA-Z] [a-zA-Z0-9]*;
STRING: '"' (~["\r\n] | '\\"')* '"';

COL : ':';
DOT : '.';
COMMA : ',';
SEMICOLON : ';';

ASSIGN : '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';

ADD_ASSIGN: '+=';
SUB_ASSIGN: '-=';
MUL_ASSIGN: '*=';
DIV_ASSIGN: '/=';
MOD_ASSIGN: '%=';

NOT: '!';
EQ: '==';
NEQ: '!=';
LT: '<';
GT: '>';
LTE: '<=';
GTE: '>=';

AND: '&&';
OR: '||';

WS: [ \t\r\n]+ -> skip;
COMMENT: '/*' .*? '*/' -> skip;
LINE_COMMENT: '//' ~[\r\n]* -> skip;
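ANTLR lexers always take the longest possible match. As a result, <= becomes a single LTE token rather than LT followed by ASSIGN. Rule order only breaks ties between matches of the same length, which is why "function" is lexed as FUNCTION rather than IDENTIFIER. The hand-written Go sketch below illustrates the same maximal-munch rule, but only for the operator subset of the grammar. lexOps is a hypothetical name, and identifiers, numbers and strings are deliberately left out.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// A subset of the operators defined in the lexer grammar above.
var operators = []string{
	"+=", "-=", "*=", "/=", "%=", "==", "!=", "<=", ">=", "&&", "||",
	"=", "+", "-", "*", "/", "%", "!", "<", ">",
}

// lexOps splits src into operator tokens using maximal munch:
// at each position the longest operator that matches is taken.
func lexOps(src string) ([]string, error) {
	ops := append([]string(nil), operators...)
	// Try longer operators first so that "<=" wins over "<".
	sort.Slice(ops, func(i, j int) bool { return len(ops[i]) > len(ops[j]) })

	var toks []string
	for i := 0; i < len(src); {
		if src[i] == ' ' {
			i++
			continue
		}
		matched := false
		for _, op := range ops {
			if strings.HasPrefix(src[i:], op) {
				toks = append(toks, op)
				i += len(op)
				matched = true
				break
			}
		}
		if !matched {
			return nil, fmt.Errorf("unexpected character %q at %d", src[i], i)
		}
	}
	return toks, nil
}

func main() {
	toks, err := lexOps("<= < = ==")
	fmt.Println(toks, err) // [<= < = ==] <nil>
}
```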

Parsing

The parser is also generated with antlr4, from the following grammar:

parser grammar JavaScriptParser;

options { tokenVocab=JavaScriptLexer; }

num : NUMBER;
id : IDENTIFIER;
str: STRING;
arr: '[' str (',' str)* ']';
key : id ':' expr;
obj: '{' key (',' key)* '}';

funcdef: 'function' IDENTIFIER '(' paramlist? ')' block;
paramlist: param (',' param)*;
param: IDENTIFIER ('=' expr)?;

funcall: IDENTIFIER '(' exprlist? ')';
exprlist: expr (',' expr)*;

program : (global)* ;

global: funcdef
| stmg
;

expr: expr ('*' | '/') expr
| expr ('+' | '-') expr
| expr ('<' | '>' | '==' | '<=' | '>=' | '!=') expr
| expr ('&&' | '||') expr
| '!' expr
| '-' expr
| '(' expr ')'
| id '.' funcall
| id '.' id
| funcall '.' funcall
| funcall '.' id
| funcall
| num
| id
| str
| arr
| obj
;

stmg: stm ';'?;
stm: expr ('=' | '+=' | '-=' | '*=' | '/=') expr
| 'var' IDENTIFIER ('=' expr)?
| 'return' expr
| 'break'
| 'continue'
| if
| while
| block
| expr
;

if: 'if' '(' expr ')' block ('else' block)?;
while: 'while' '(' expr ')' block;
block: '{' (stm ';'?)* '}';
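ANTLR4 handles left-recursive rules in a specific way. Alternatives listed earlier bind more tightly, and all operators inside one alternative share a precedence level and associate to the left. If *, /, + and - sit in a single alternative, 1 + 2 * 3 is grouped as (1 + 2) * 3 and evaluates to 9. Giving multiplication and addition separate alternatives, with the multiplicative one first, yields the usual 7.

The Go sketch below is not generated by ANTLR. It is a small precedence-climbing evaluator, with hypothetical names prec, parser and eval, that shows the two-level grouping the grammar should produce.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// prec gives each binary operator its binding power; higher binds tighter.
var prec = map[string]int{"+": 1, "-": 1, "*": 2, "/": 2}

type parser struct {
	toks []string
	pos  int
}

func (p *parser) peek() string {
	if p.pos < len(p.toks) {
		return p.toks[p.pos]
	}
	return ""
}

// parseExpr keeps consuming operators whose precedence is at least minPrec.
// The right operand is parsed with pr+1, so equal precedence associates left.
func (p *parser) parseExpr(minPrec int) int {
	lhs := p.parseAtom()
	for {
		op := p.peek()
		pr, ok := prec[op]
		if !ok || pr < minPrec {
			return lhs
		}
		p.pos++
		rhs := p.parseExpr(pr + 1)
		switch op {
		case "+":
			lhs += rhs
		case "-":
			lhs -= rhs
		case "*":
			lhs *= rhs
		case "/":
			lhs /= rhs
		}
	}
}

func (p *parser) parseAtom() int {
	n, err := strconv.Atoi(p.peek())
	if err != nil {
		panic(fmt.Sprintf("expected number, got %q", p.peek()))
	}
	p.pos++
	return n
}

// eval expects numbers and operators separated by spaces.
func eval(src string) int {
	p := &parser{toks: strings.Fields(src)}
	return p.parseExpr(1)
}

func main() {
	fmt.Println(eval("1 + 2 * 3")) // 7
	fmt.Println(eval("8 - 3 - 2")) // 3
}
```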

Semantic analysis

Semantic analysis mainly uses the listener mechanism of antlr4 and deals with the global functions and statements:

package listener

import (
"fmt"

Jsp "github.com/klang/js-go/JavaScript"
)

type MyListener struct {
*Jsp.BaseJavaScriptParserListener
}

func (l *MyListener) EnterProgram(ctx *Jsp.ProgramContext) {
vl = NewSymTable()
}

func (l *MyListener) EnterFuncdef(ctx *Jsp.FuncdefContext) {
var args []string
name := fmt.Sprintf("%v", ctx.GetChild(1))
for _, v := range ctx.GetChild(3).GetChildren() {
if v.GetChild(0) != nil {
args = append(args, fmt.Sprintf("%v", v.GetChild(0)))
}
}
vl.addFunc(name, args, ctx)
vl.showFuncAll()
}

func (l *MyListener) EnterStmg(ctx *Jsp.StmgContext) {
handStm(ctx.GetChild(0).(*Jsp.StmContext))
}
  • Functions are simply recorded in SymTable
  • Statements need detailed handling

Since JavaScript is an interpreted language, I use a dynamically changing SymTable to record the data of each variable at run time. Its structure is as follows:

type SymTable struct {
Variants []Variant
Fuctions []Fuction
InFuctions inFunction
}

type Variant struct {
name string
value string
typ int
}

type Fuction struct {
name string
args []string
fctx *Jsp.FuncdefContext
}

type inFunction struct {
log func(string) int
}

The corresponding helper functions are:

func NewSymTable() *SymTable {
sym := &SymTable{
Variants: []Variant{},
Fuctions: []Fuction{},
InFuctions: inFunction{},
}
sym.initBuildInFunc()
return sym
}

func NewVariant(name string, value string, typ int) *Variant {
return &Variant{name, value, typ}
}

func (vl *SymTable) showVarAll() {
fmt.Println("|-----------------SymTable-----------------------|")
fmt.Println("|name \t\t type \t\t value \t\t |")
fmt.Println("|------------------------------------------------|")
for _, v := range vl.Variants {
fmt.Printf("|%-15v %-15v %-15v |\n", v.name, v.typ, v.value)
}
fmt.Printf("|------------------------------------------------|\n\n")
}

func (vl *SymTable) addVar(name string, value string, typ int) {
vl.Variants = append(vl.Variants, *NewVariant(name, value, typ))
}

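func (vl *SymTable) getVarByName(name string) (string, int) {
// Scan from the newest entry so that function parameters shadow older variables
for i := len(vl.Variants) - 1; i >= 0; i-- {
if vl.Variants[i].name == name {
return vl.Variants[i].value, vl.Variants[i].typ
}
}
return "", -1
}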

func (vl *SymTable) getvarByIndex(index int) (string, int) {
return vl.Variants[index].value, vl.Variants[index].typ
}

func (vl *SymTable) getVarlen() int {
return len(vl.Variants)
}

func (vl *SymTable) delVar() {
vl.Variants = vl.Variants[:len(vl.Variants)-1]
}

func NewFunction(name string, args []string, ctx *Jsp.FuncdefContext) *Fuction {
return &Fuction{name, args, ctx}
}

func (vl *SymTable) addFunc(name string, args []string, ctx *Jsp.FuncdefContext) {
vl.Fuctions = append(vl.Fuctions, *NewFunction(name, args, ctx))
}

func (vl *SymTable) getFunc(name string) ([]string, *Jsp.FuncdefContext) {
for _, f := range vl.Fuctions {
if f.name == name {
return f.args, f.fctx
}
}
return nil, nil
}

func (vl *SymTable) showFuncAll() {
fmt.Println("|----------Function--------------|")
fmt.Println("|name \t\t args \t\t |")
for _, v := range vl.Fuctions {
fmt.Printf("|%-15v[ ", v.name)
for _, a := range v.args {
fmt.Printf("%v ", a)
}
fmt.Println("]")
}
fmt.Printf("|--------------------------------|\n\n")
}

func (vl *SymTable) initBuildInFunc() {
vl.InFuctions.log = func(msg string) int {
fmt.Printf("[+]log: %v\n", msg)
return 0
}
}

func (vl *SymTable) callInFunc(name string, args []Variant) string {
// Guard against log() being called without arguments
if name == "log" && len(args) > 0 {
vl.InFuctions.log(args[0].value)
}
return ""
}
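Function parameters are appended to the end of Variants and later removed from the end with delVar, so the table is used as a stack. Looking names up from the newest entry backwards lets a parameter shadow a global with the same name. A scan starting at index 0 would find the older global first. Here is a minimal standalone sketch of that stack discipline, with hypothetical names varStack, push, popN and get.

```go
package main

import "fmt"

type variant struct {
	name, value string
}

// varStack is a flat symbol table used as a stack: newer bindings sit at the end.
type varStack struct {
	vars []variant
}

func (s *varStack) push(name, value string) {
	s.vars = append(s.vars, variant{name, value})
}

// popN removes the n most recent bindings, e.g. a function's parameters.
func (s *varStack) popN(n int) {
	s.vars = s.vars[:len(s.vars)-n]
}

// get scans from the top of the stack so the innermost binding wins.
func (s *varStack) get(name string) (string, bool) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		if s.vars[i].name == name {
			return s.vars[i].value, true
		}
	}
	return "", false
}

func main() {
	var s varStack
	s.push("x", "1") // global x
	s.push("x", "2") // parameter x of a called function
	v, _ := s.get("x")
	fmt.Println(v) // 2
	s.popN(1) // the function returns
	v, _ = s.get("x")
	fmt.Println(v) // 1
}
```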

Statements are handled by the function handStm:

func handStm(stm *Jsp.StmContext) (string, int, bool) {
str_1 := fmt.Sprintf("%v", stm.GetChild(0))
if str_1 == "return" {
value, typ := handExpr(stm.GetChild(1).(*Jsp.ExprContext))
return value, typ, true
}
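if str_1 == "var" {
name := fmt.Sprintf("%v", stm.GetChild(1))
// "var x" without an initializer has no child at index 3
value, typ := "", -1
if stm.GetChildCount() > 3 {
value, typ = handExpr(stm.GetChild(3).(*Jsp.ExprContext))
}
vl.addVar(name, value, typ)
vl.showVarAll()
}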

if exp, ok := stm.GetChild(0).(*Jsp.ExprContext); ok {
handExpr(exp)
}

return "", -1, false
}
  • Currently only return statements, var statements and expression statements are supported

The core of the interpreter is handExpr, the function that evaluates expressions:

func handExpr(exp *Jsp.ExprContext) (value string, typ int) {
if exp.GetChildCount() == 3 {
op := fmt.Sprintf("%v", exp.GetChild(1))
exp1, typ1 := handExpr(exp.GetChild(0).(*Jsp.ExprContext))
exp2, typ2 := handExpr(exp.GetChild(2).(*Jsp.ExprContext))

if op == "+" {
if typ1 == Jsp.JavaScriptParserNUMBER && typ2 == Jsp.JavaScriptParserNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.Itoa(num1 + num2), Jsp.JavaScriptParserNUMBER
} else {
if exp1[0] == '"' {
exp1 = exp1[1 : len(exp1)-1]
}
if exp2[0] == '"' {
exp2 = exp2[1 : len(exp2)-1]
}
return "\"" + exp1 + exp2 + "\"", Jsp.JavaScriptLexerSTRING
}
}
if op == "-" {
if typ1 == Jsp.JavaScriptParserNUMBER && typ2 == Jsp.JavaScriptParserNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.Itoa(num1 - num2), Jsp.JavaScriptParserNUMBER
} else {
return "", -1
}
}
if op == "*" {
if typ1 == Jsp.JavaScriptParserNUMBER && typ2 == Jsp.JavaScriptParserNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.Itoa(num1 * num2), Jsp.JavaScriptParserNUMBER
} else {
return "", -1
}
}
if op == "/" {
if typ1 == Jsp.JavaScriptParserNUMBER && typ2 == Jsp.JavaScriptParserNUMBER {
num1, _ := strconv.Atoi(exp1)
num2, _ := strconv.Atoi(exp2)
return strconv.Itoa(num1 / num2), Jsp.JavaScriptParserNUMBER
} else {
return "", -1
}
}
}

if exp.GetChildCount() == 2 {
op := fmt.Sprintf("%v", exp.GetChild(0))
exp1, typ1 := handExpr(exp.GetChild(1).(*Jsp.ExprContext))

if op == "-" {
if typ1 == Jsp.JavaScriptParserNUMBER {
num1, _ := strconv.Atoi(exp1)
return strconv.Itoa(-num1), Jsp.JavaScriptParserNUMBER
} else {
return "", -1
}
}
}

if exp.GetChildCount() == 1 {
if exp.GetChild(0).GetChildCount() > 1 {
return handExpr_funcall(exp.GetChild(0).(*Jsp.FuncallContext))
} else {
renum := regexp.MustCompile(`^\d+$`)
resym := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
str := fmt.Sprintf("%v", exp.GetChild(0).GetChild(0))
if str[0] == '"' {
return str, Jsp.JavaScriptLexerSTRING
}
if renum.MatchString(str) {
return str, Jsp.JavaScriptLexerNUMBER
}
if resym.MatchString(str) {
value, typ := vl.getVarByName(str)
return value, typ
}
}
}

return "", -1
}
  • It handles both binary and unary expressions
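The lexer's NUMBER rule also accepts decimals such as 1.5, but handExpr has two gaps here. It only recognizes digit-only literals (the ^\d+$ check), and it converts operands with strconv.Atoi, which returns 0 and an error for "1.5" (the error is discarded).

Below is a hedged sketch of the arithmetic on string-encoded operands that uses strconv.ParseFloat instead and reports errors rather than ignoring them. arith is a hypothetical helper. It returns an error for division by zero, where JavaScript would produce Infinity.

```go
package main

import (
	"fmt"
	"strconv"
)

// arith applies a binary arithmetic operator to two numeric literals kept as
// strings and returns the result in its shortest decimal form.
func arith(op, a, b string) (string, error) {
	x, err := strconv.ParseFloat(a, 64)
	if err != nil {
		return "", err
	}
	y, err := strconv.ParseFloat(b, 64)
	if err != nil {
		return "", err
	}
	var r float64
	switch op {
	case "+":
		r = x + y
	case "-":
		r = x - y
	case "*":
		r = x * y
	case "/":
		if y == 0 {
			return "", fmt.Errorf("division by zero")
		}
		r = x / y
	default:
		return "", fmt.Errorf("unsupported operator %q", op)
	}
	// 'g' with precision -1 prints the shortest representation: "3", not "3.000000".
	return strconv.FormatFloat(r, 'g', -1, 64), nil
}

func main() {
	s, _ := arith("+", "1.5", "2")
	fmt.Println(s) // 3.5
	s, _ = arith("/", "7", "2")
	fmt.Println(s) // 3.5
}
```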

One special case inside expression statements needs separate handling, namely function calls:

func handExpr_funcall(call *Jsp.FuncallContext) (string, int) {
var args []Variant = []Variant{}

name := fmt.Sprintf("%v", call.GetChild(0))

for _, c := range call.GetChild(2).GetChildren() {
if c.GetChild(0) != nil {
value, typ := handExpr(c.(*Jsp.ExprContext))
args = append(args, Variant{"unknown", value, typ})
}
}
return handFuncdef(name, args)
}
  • This function first collects the arguments of the call, then enters the function definition for further processing

Function definitions are handled by the following code:

func handFuncdef(name string, args []Variant) (string, int) {
var revalue string
var retyp int
params, ctx := vl.getFunc(name)
if ctx == nil {
return vl.callInFunc(name, args), -1
}

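if len(params) != len(args) {
// Bail out before binding anything, so the cleanup loop below
// never pops variables that were not pushed
fmt.Printf("%s: expected %d arguments, got %d\n", name, len(params), len(args))
return "", -1
}
for i, param := range params {
value, typ := args[i].value, args[i].typ
vl.addVar(param, value, typ)
}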

for _, c := range ctx.GetChild(ctx.GetChildCount() - 1).GetChildren() {
if c.GetChild(0) != nil {
value, typ, ok := handStm(c.(*Jsp.StmContext))
revalue, retyp = value, typ
if ok {
break
}
}
}

for i := 0; i < len(params); i++ {
vl.delVar()
}

return revalue, retyp
}
  • The function is first looked up in SymTable; if it is not found there, the built-in functions are searched

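Scope handling

Translating the AST into LLVM IR requires handling nested lexical scopes. LLVM IR has no scope statements of its own (lexical blocks only appear in debug metadata). Locals are simply values named within a function and globals are module-level symbols, so the front end must resolve each name to its declaration and give it a unique name

Scoping belongs to semantic analysis and is only needed when the compiler package generates LLVM assembly, so the Scope type that manages lexical scopes is defined in the compiler package:
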
type Scope struct {
Outer *Scope
Objects map[string]*Object
}

type Object struct {
Name string
MangledName string
ast.Node
}
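  • Outer in a Scope points to the enclosing Scope; for example, the enclosing Scope of the main function is the file scope
  • Object represents a named entity: Name is its name within the lexical scope of the uGo code, MangledName is the name it maps to in LLVM assembly, and the embedded ast.Node points to the AST node, which can be used to recover further information

The corresponding helper functions are:
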
func NewScope(outer *Scope) *Scope {
return &Scope{
Outer: outer,
Objects: make(map[string]*Object),
}
}

func (s *Scope) HasName(name string) bool {
_, ok := s.Objects[name]
return ok
}

func (s *Scope) Lookup(name string) (*Scope, *Object) {
for ; s != nil; s = s.Outer {
if obj := s.Objects[name]; obj != nil {
return s, obj
}
}
return nil, nil
}

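// Insert adds obj unless its name is already declared in this scope,
// in which case the existing object is returned and nothing changes
func (s *Scope) Insert(obj *Object) (alt *Object) {
if alt = s.Objects[obj.Name]; alt == nil {
s.Objects[obj.Name] = obj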
}
return
}

func (s *Scope) String() string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "scope %p {", s)
if s != nil && len(s.Objects) > 0 {
fmt.Fprintln(&buf)
for _, obj := range s.Objects {
fmt.Fprintf(&buf, "\t%T %s\n", obj.Node, obj.Name)
}
}
fmt.Fprintf(&buf, "}\n")
return buf.String()
}
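Here is a standalone version of the scope chain that runs without the ast package, with Object reduced to Name and MangledName. It assumes Insert is meant to follow the convention of Go's go/ast package. Under that convention, inserting a name that already exists leaves the scope unchanged and returns the existing object; any other name is added and nil is returned. The mangled names in main merely imitate the formats used by the compiler.

```go
package main

import "fmt"

type Object struct {
	Name        string
	MangledName string
}

type Scope struct {
	Outer   *Scope
	Objects map[string]*Object
}

func NewScope(outer *Scope) *Scope {
	return &Scope{Outer: outer, Objects: make(map[string]*Object)}
}

// Insert adds obj unless the name already exists in this scope,
// in which case the existing object is returned and nothing changes.
func (s *Scope) Insert(obj *Object) (alt *Object) {
	if alt = s.Objects[obj.Name]; alt == nil {
		s.Objects[obj.Name] = obj
	}
	return
}

// Lookup walks outward through enclosing scopes until the name is found.
func (s *Scope) Lookup(name string) (*Scope, *Object) {
	for ; s != nil; s = s.Outer {
		if obj := s.Objects[name]; obj != nil {
			return s, obj
		}
	}
	return nil, nil
}

func main() {
	file := NewScope(nil)
	file.Insert(&Object{"x", "@ugo_main_x"})
	fn := NewScope(file)
	fn.Insert(&Object{"x", "%local_x.pos.1"}) // shadows the global x
	_, obj := fn.Lookup("x")
	fmt.Println(obj.MangledName) // %local_x.pos.1
	_, obj = file.Lookup("x")
	fmt.Println(obj.MangledName) // @ugo_main_x
}
```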

Building on Scope, we can define a Compiler type that describes the compiler:

type Compiler struct {
file *ast.File
scope *Scope
nextId int
}

The following functions enter and leave an inner Scope (restoreScope returns to a previously saved Scope):

func (p *Compiler) enterScope() {
p.scope = NewScope(p.scope)
}

func (p *Compiler) leaveScope() {
p.scope = p.scope.Outer
}

func (p *Compiler) restoreScope(scope *Scope) {
p.scope = scope
}
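  • A new Scope must be entered when compiling a file or a block statement, so that variables defined in the new lexical scope can be stored in it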
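The defer p.restoreScope(p.scope) / p.enterScope() pair relies on a Go rule. The arguments of a deferred call are evaluated when the defer statement executes, not when the surrounding function returns. The saved pointer is therefore the scope that was current before enterScope. Here is a small sketch with a hypothetical compiler type that only tracks nesting depth.

```go
package main

import "fmt"

type scope struct {
	outer *scope
	depth int
}

type compiler struct{ scope *scope }

func (c *compiler) enterScope() {
	d := 0
	if c.scope != nil {
		d = c.scope.depth + 1
	}
	c.scope = &scope{outer: c.scope, depth: d}
}

func (c *compiler) restoreScope(s *scope) { c.scope = s }

// compileBlock enters a scope and restores the caller's scope on return.
// c.scope is captured as the defer argument before enterScope changes it.
func (c *compiler) compileBlock(nested int, seen *[]int) {
	defer c.restoreScope(c.scope)
	c.enterScope()
	*seen = append(*seen, c.scope.depth)
	if nested > 0 {
		c.compileBlock(nested-1, seen)
	}
}

func main() {
	c := &compiler{}
	c.enterScope() // file scope, depth 0
	var seen []int
	c.compileBlock(2, &seen)
	fmt.Println(seen, c.scope.depth) // [1 2 3] 0
}
```

Unlike calling leaveScope at the end of the function, the deferred restore also runs if the function returns early or panics.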

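Global variables are handled while compiling the file, specifically in the compileFile function

Translating to LLVM IR

The next step is to translate the AST into LLVM IR. I moved type checking forward into the parsing stage, so error reporting does not need to be considered here

The core functions are:
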
func (p *Compiler) genHeader(w io.Writer, file *ast.File) {
fmt.Fprintf(w, "; package %s\n", file.Pkg.Name)
fmt.Fprint(w, Header)
}

func (p *Compiler) genMain(w io.Writer, file *ast.File) {
if file.Pkg.Name != "main" {
return
}
for _, fn := range file.Funcs {
if fn.Name.Name == "main" {
fmt.Fprint(w, MainMain)
return
}
}
}

func (p *Compiler) Run(file *ast.File) string {
var buf bytes.Buffer

p.file = file

p.genHeader(&buf, file)
p.compileFile(&buf, file)
p.genMain(&buf, file)

return buf.String()
}

The function compileFile is the starting point of the LLVM IR translation. Its full code is:

func (p *Compiler) compileFile(w io.Writer, file *ast.File) {
defer p.restoreScope(p.scope)
p.enterScope()
for _, g := range file.Globals {
var mangledName = fmt.Sprintf("@ugo_%s_%s", file.Pkg.Name, g.Name.Name)
p.scope.Insert(&Object{
Name: g.Name.Name,
MangledName: mangledName,
Node: g,
})
fmt.Fprintf(w, "%s = global i32 0\n", mangledName)
}

for _, g := range file.Funcs {
var mangledName = fmt.Sprintf("declare i32 @ugo_%s_%s", file.Pkg.Name, g.Name.Name)
p.scope.Insert(&Object{
Name: g.Name.Name,
MangledName: mangledName,
Node: g,
})
fmt.Fprintf(w, "%s(%s)\n", mangledName, p.changeType(g.Type.Name))
}

if len(file.Globals) != 0 {
fmt.Fprintln(w)
}
for _, fn := range file.Funcs {
p.compileFunc(w, file, fn)
}

p.genInit(w, file)
}
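  • enterScope is called at the start of the function to enter a new scope
  • restoreScope is deferred so that, when the function ends, the compiler returns to the enclosing scope
  • For a uGo source file, the File level is the outermost scope, where functions and global variables are recorded

Functions are handled as follows:
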
func (p *Compiler) compileFunc(w io.Writer, file *ast.File, fn *ast.FuncDecl) {
defer p.restoreScope(p.scope)
p.enterScope()

var mangledName = fmt.Sprintf("@ugo_%s_%s", file.Pkg.Name, fn.Name.Name)

p.scope.Insert(&Object{
Name: fn.Name.Name,
MangledName: mangledName,
Node: fn,
})

if fn.Body == nil {
fmt.Fprintf(w, "declare i32 @ugo_%s_%s()\n", file.Pkg.Name, fn.Name.Name)
return
}
fmt.Fprintln(w)

fmt.Fprintf(w, "define i32 @ugo_%s_%s() {\n", file.Pkg.Name, fn.Name.Name)
p.compileStmt(w, fn.Body)
fmt.Fprintln(w, "\tret i32 0")
fmt.Fprintln(w, "}")
}
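  • Every function registers its symbol in the current scope with scope.Insert
  • After this basic setup, compileStmt takes care of the statement block

The core job of compileStmt is to handle the following kinds of statements:

  • VarSpec: variable declaration
  • AssignStmt: assignment
  • BlockStmt: block statement
  • ExprStmt: expression statement
  • RetStmt: return statement
  • IfStmt: if statement
  • ForStmt: for statement

Except for if statements, for statements and expressions, the implementations are fairly simple and are given directly:
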
case *ast.VarSpec:
var localName = "0"
if stmt.Value != nil {
localName = p.compileExpr(w, stmt.Value)
}

var mangledName = fmt.Sprintf("%%local_%s.pos.%d", stmt.Name.Name, stmt.VarPos)
p.scope.Insert(&Object{
Name: stmt.Name.Name,
MangledName: mangledName,
Node: stmt,
})

fmt.Fprintf(w, "\t%s = alloca i32, align 4\n", mangledName)
fmt.Fprintf(
w, "\tstore i32 %s, i32* %s\n",
localName, mangledName,
)

case *ast.AssignStmt:
p.compileStmt_assign(w, stmt)

case *ast.BlockStmt:
defer p.restoreScope(p.scope)
p.enterScope()

for _, x := range stmt.List {
p.compileStmt(w, x)
}

case *ast.ExprStmt:
p.compileExpr(w, stmt.X)

case *ast.RetStmt:
var retValue = p.compileExpr(w, stmt.Expr)
fmt.Fprintf(w, "\tret i32 %s\n", retValue)

An if statement has to respect the scope of each of its parts. The code therefore uses immediately invoked anonymous functions, each with its own enterScope and deferred restoreScope, to bound every scope:

case *ast.IfStmt:
defer p.restoreScope(p.scope)
p.enterScope()

ifPos := fmt.Sprintf("%d", p.posLine())
ifInit := p.genLabelId("if.init.line" + ifPos)
ifCond := p.genLabelId("if.cond.line" + ifPos)
ifBody := p.genLabelId("if.body.line" + ifPos)
ifElse := p.genLabelId("if.else.line" + ifPos)
ifEnd := p.genLabelId("if.end.line" + ifPos)

// br if.init
fmt.Fprintf(w, "\tbr label %%%s\n", ifInit)

// if.init
fmt.Fprintf(w, "\n%s:\n", ifInit)
func() {
defer p.restoreScope(p.scope)
p.enterScope()

if stmt.Init != nil {
p.compileStmt(w, stmt.Init)
fmt.Fprintf(w, "\tbr label %%%s\n", ifCond)
} else {
fmt.Fprintf(w, "\tbr label %%%s\n", ifCond)
}

// if.cond
{
fmt.Fprintf(w, "\n%s:\n", ifCond)
condValue := p.compileExpr(w, stmt.Cond)
if stmt.Else != nil {
fmt.Fprintf(w, "\tbr i1 %s , label %%%s, label %%%s\n", condValue, ifBody, ifElse)
} else {
fmt.Fprintf(w, "\tbr i1 %s , label %%%s, label %%%s\n", condValue, ifBody, ifEnd)
}
}

// if.body
func() {
defer p.restoreScope(p.scope)
p.enterScope()

fmt.Fprintf(w, "\n%s:\n", ifBody)
p.compileStmt(w, stmt.Body)
// The then-branch always jumps to if.end; branching to if.else here
// would execute the else branch as well
fmt.Fprintf(w, "\tbr label %%%s\n", ifEnd)
}()

// if.else
func() {
defer p.restoreScope(p.scope)
p.enterScope()

fmt.Fprintf(w, "\n%s:\n", ifElse)
if stmt.Else != nil {
p.compileStmt(w, stmt.Else)
fmt.Fprintf(w, "\tbr label %%%s\n", ifEnd)
} else {
fmt.Fprintf(w, "\tbr label %%%s\n", ifEnd)
}
}()
}()

// end
fmt.Fprintf(w, "\n%s:\n", ifEnd)

The for statement follows the same approach:

case *ast.ForStmt:
defer p.restoreScope(p.scope)
p.enterScope()

forPos := fmt.Sprintf("%d", p.posLine())
forInit := p.genLabelId("for.init.line" + forPos)
forCond := p.genLabelId("for.cond.line" + forPos)
forPost := p.genLabelId("for.post.line" + forPos)
forBody := p.genLabelId("for.body.line" + forPos)
forEnd := p.genLabelId("for.end.line" + forPos)

// br for.init
fmt.Fprintf(w, "\tbr label %%%s\n", forInit)

// for.init
func() {
defer p.restoreScope(p.scope)
p.enterScope()

fmt.Fprintf(w, "\n%s:\n", forInit)
if stmt.Init != nil {
p.compileStmt(w, stmt.Init)
fmt.Fprintf(w, "\tbr label %%%s\n", forCond)
} else {
fmt.Fprintf(w, "\tbr label %%%s\n", forCond)
}

// for.cond
fmt.Fprintf(w, "\n%s:\n", forCond)
if stmt.Cond != nil {
condValue := p.compileExpr(w, stmt.Cond)
fmt.Fprintf(w, "\tbr i1 %s , label %%%s, label %%%s\n", condValue, forBody, forEnd)
} else {
fmt.Fprintf(w, "\tbr label %%%s\n", forBody)
}

// for.body
func() {
defer p.restoreScope(p.scope)
p.enterScope()

fmt.Fprintf(w, "\n%s:\n", forBody)
p.compileStmt(w, stmt.Body)
fmt.Fprintf(w, "\tbr label %%%s\n", forPost)
}()

// for.post
{
fmt.Fprintf(w, "\n%s:\n", forPost)
if stmt.Post != nil {
p.compileStmt(w, stmt.Post)
fmt.Fprintf(w, "\tbr label %%%s\n", forCond)
} else {
fmt.Fprintf(w, "\tbr label %%%s\n", forCond)
}
}
}()

// end
fmt.Fprintf(w, "\n%s:\n", forEnd)

Finally, expressions are handled by the compileExpr function:

func (p *Compiler) compileExpr(w io.Writer, expr ast.Expr) (localName string) {
switch expr := expr.(type) {
case *ast.CallExpr:
if _, obj := p.scope.Lookup(expr.FuncName.Name); obj != nil {
localName = p.genId()
varStr := fmt.Sprintf("\t%s = call i32(i32) @ugo_main_%s(",
localName, expr.FuncName.Name,
)
for i, arg := range expr.Args {
if i == 0 {
varStr += p.changeType(obj.Node.(*ast.FuncDecl).Type.Name) + " " + p.compileExpr(w, arg)
} else {
varStr += ", " + p.changeType(obj.Node.(*ast.FuncDecl).Type.Name) + " " + p.compileExpr(w, arg)
}

}
varStr += ")\n"

fmt.Fprint(w, varStr)

} else {
panic(fmt.Sprintf("func %s undefined", expr.FuncName.Name))
}
return localName
case *ast.Ident:
var varName string
if _, obj := p.scope.Lookup(expr.Name); obj != nil {
varName = obj.MangledName
} else {
panic(fmt.Sprintf("var %s undefined", expr.Name))
}

localName = p.genId()
fmt.Fprintf(w, "\t%s = load i32, i32* %s, align 4\n",
localName, varName,
)
return localName
case *ast.IdentAS:
var varName string
if _, obj := p.scope.Lookup(expr.Name); obj != nil {
varName = obj.MangledName

call, ok := expr.Offset.(*ast.CallExpr)
if ok {
varName = p.compileExpr(w, call)
}
} else {
panic(fmt.Sprintf("var %s undefined", expr.Name))
}

localName = p.genId()
fmt.Fprintf(w, "\t%s = load i32, i32* %s, align 4\n",
localName, varName,
)
return localName
case *ast.Number:
localName = p.genId()
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "add", `0`, expr.Value,
)
return localName
case *ast.BinaryExpr:
localName = p.genId()
switch expr.Op {
case token.ADD:
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "add", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.SUB:
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "sub", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.MUL:
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "mul", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.DIV:
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
// LLVM IR has no plain div instruction; sdiv is signed integer division
localName, "sdiv", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.MOD:
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "srem", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.EQL: // ==
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "icmp eq", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.NEQ: // !=
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "icmp ne", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.LSS: // <
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "icmp slt", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.LEQ: // <=
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "icmp sle", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.GTR: // >
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "icmp sgt", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
case token.GEQ: // >=
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "icmp sge", p.compileExpr(w, expr.X), p.compileExpr(w, expr.Y),
)
return localName
default:
panic(fmt.Sprintf("unknown: %[1]T, %[1]v", expr))
}
case *ast.UnaryExpr:
if expr.Op == token.SUB {
localName = p.genId()
fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
localName, "sub", `0`, p.compileExpr(w, expr.X),
)
return localName
}
return p.compileExpr(w, expr.X)
case *ast.ParenExpr:
return p.compileExpr(w, expr.X)

default:
panic(fmt.Sprintf("unknown: %[1]T, %[1]v", expr))
}
}
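Here is a reduced standalone sketch of the same recursive lowering for integer arithmetic. Two details differ from compileExpr on purpose.

First, both operands are compiled before the result name is allocated, so names are handed out in the order instructions are emitted. This matters if plain numeric names such as %1 are used, because LLVM requires unnamed values to be numbered in order of definition. The real genId is not shown, so this may or may not apply to it.

Second, integer literals are used directly as operands instead of being materialized with add i32 0, v.

The types num and bin and the %tN naming are hypothetical.

```go
package main

import (
	"fmt"
	"strings"
)

type expr interface{}

type num struct{ v int }

type bin struct {
	op   byte
	x, y expr
}

type gen struct {
	sb strings.Builder
	id int
}

func (g *gen) newID() string {
	g.id++
	return fmt.Sprintf("%%t%d", g.id)
}

var opcodes = map[byte]string{'+': "add", '-': "sub", '*': "mul", '/': "sdiv", '%': "srem"}

// compile emits instructions for e and returns the operand that holds its value.
func (g *gen) compile(e expr) string {
	switch e := e.(type) {
	case num:
		return fmt.Sprint(e.v) // constants can be used directly as operands
	case bin:
		x := g.compile(e.x)
		y := g.compile(e.y)
		name := g.newID() // allocated after the operands
		fmt.Fprintf(&g.sb, "\t%s = %s i32 %s, %s\n", name, opcodes[e.op], x, y)
		return name
	default:
		panic(fmt.Sprintf("unknown expression %T", e))
	}
}

func main() {
	var g gen
	// (1 + 2) * 3
	res := g.compile(bin{'*', bin{'+', num{1}, num{2}}, num{3}})
	fmt.Print(g.sb.String())
	fmt.Println("result in", res)
}
```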

Completing function statements

For a function, the key information is its name, parameter list, return type and statement block

The corresponding AST node structures are:

// Global function/method
type FuncDecl struct {
FuncPos int // position of the func keyword
Params *FieldList // parameter types
Name *Ident // function name
Type *Ident // return type
Body *BlockStmt // statement block of the function body
}

// Parameter/field list
type FieldList struct {
List []*Field
}
  • Multiple return values are not supported yet

Handling function definitions

The function parseFunc parses a function definition. Its initial implementation is:

func (p *Parser) parseFunc() *ast.FuncDecl {
tokFunc := p.MustAcceptToken(token.FUNC)
tokFuncIdent := p.MustAcceptToken(token.IDENT)

p.MustAcceptToken(token.LPAREN)
funcFieldList := &ast.FieldList{}

if _, ok := p.AcceptToken(token.RPAREN); !ok {
for {
tokVal := p.MustAcceptToken(token.IDENT)
tokTyp := p.MustAcceptToken(token.IDENT)
funcField := &ast.Field{
Name: &ast.Ident{
NamePos: tokVal.Pos,
Name: tokVal.Literal,
},
Type: &ast.Ident{
NamePos: tokTyp.Pos,
Name: tokTyp.Literal,
}}
funcFieldList.List = append(funcFieldList.List, funcField)
if ok = p.SymTab.Push(funcField.Name, funcField.Type); !ok {
p.errorf(tokVal.Pos, "duplicate variable declaration: %s", tokVal.Literal)
}
if _, ok := p.AcceptToken(token.COMMA); !ok {
break
}
}
p.MustAcceptToken(token.RPAREN)
}

funcName := &ast.Ident{
NamePos: tokFuncIdent.Pos,
Name: tokFuncIdent.Literal,
}

funcType := &ast.Ident{}
if tokTypeIdent, ok := p.AcceptToken(token.IDENT); ok {
funcType = &ast.Ident{
NamePos: tokTypeIdent.Pos,
Name: tokTypeIdent.Literal,
}
}

return &ast.FuncDecl{
FuncPos: tokFunc.Pos,
Params: funcFieldList,
Name: funcName,
Type: funcType,
Body: p.parseStmt_block(),
}
}
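  • Go's shorthand for several parameters sharing one type (such as a, b int) is not supported yet
  • Function parameters are temporarily added to the symbol table

When a function ends, its parameters and local variables have to be removed from the symbol table. The approach here is to count how many parameters and locals each statement block declares, and then perform that many Pops when the block ends

A side benefit is that nested statement blocks are handled too: variables declared in an inner block are removed as soon as that block ends

To support this, a new field has to be added to SymTab:
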
type SymTab struct {
elements []*TypeSpec
block []int
top int
level int
}
  • Each level has a counter recording how many variables are declared in the current block

Modify Push and Pop, and add Pops, which pops all parameters and locals of the current level:

func (s *SymTab) Push(name, typ *Ident) bool {
for i := s.top - 1; i >= 0; i-- {
if s.elements[i].Name.Name == name.Name {
return false
}
}
s.elements = append(s.elements, &TypeSpec{
Name: name,
Type: typ,
})
s.top++
s.block[s.level]++
return true
}

func (s *SymTab) Pop() bool {
if s.top == 0 {
return false
}
s.top--
s.elements = s.elements[:len(s.elements)-1]
return true
}

func (s *SymTab) Pops() bool {
for s.block[s.level] > 0 {
if ok := s.Pop(); !ok {
return false
}
s.block[s.level]--
}
return true
}

Add two helper functions for switching levels:

func (s *SymTab) AddLevel() {
s.level++
s.block = append(s.block, 0)
}

func (s *SymTab) SubLevel() {
s.level--
s.block = s.block[:len(s.block)-1]
}
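Here is a compact standalone sketch of the per-level counting, with the *Ident arguments reduced to plain strings. The constructor starts block with one counter for the outermost level, which the indexing s.block[s.level] at level 0 requires. The Push above rejects a name found anywhere in the table. This sketch only rejects duplicates within the current level, so an inner block may shadow an outer variable as Go allows. Treat that as a deliberate variation, not the original behavior.

```go
package main

import "fmt"

type symTab struct {
	names []string // stack of declared names
	block []int    // number of names declared at each level
	level int
}

func newSymTab() *symTab {
	return &symTab{block: []int{0}} // counter for the outermost level
}

// push declares name in the current level; it fails only if the same
// name is already declared in this level.
func (s *symTab) push(name string) bool {
	n := s.block[s.level]
	for _, v := range s.names[len(s.names)-n:] {
		if v == name {
			return false
		}
	}
	s.names = append(s.names, name)
	s.block[s.level]++
	return true
}

func (s *symTab) addLevel() {
	s.level++
	s.block = append(s.block, 0)
}

// subLevel pops every name declared in the current level, then leaves it.
func (s *symTab) subLevel() {
	s.names = s.names[:len(s.names)-s.block[s.level]]
	s.block = s.block[:s.level]
	s.level--
}

// lookup searches from the innermost declaration outward.
func (s *symTab) lookup(name string) bool {
	for i := len(s.names) - 1; i >= 0; i-- {
		if s.names[i] == name {
			return true
		}
	}
	return false
}

func main() {
	s := newSymTab()
	s.push("g")
	s.addLevel() // function parameters
	s.push("a")
	s.addLevel()                          // nested block
	fmt.Println(s.push("a"), s.push("a")) // true false
	s.subLevel()
	s.subLevel()
	fmt.Println(s.lookup("a"), s.lookup("g")) // false true
}
```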

Finally, modify parseStmt_block and parseFunc to add the level switching and the cleanup of local variables:

func (p *Parser) parseStmt_block() *ast.BlockStmt {
block := &ast.BlockStmt{}
tokBegin := p.MustAcceptToken(token.LBRACE) // position of '{'

p.SymTab.AddLevel()

......

tokEnd := p.MustAcceptToken(token.RBRACE) // position of '}'
block.Lbrace = tokBegin.Pos
block.Rbrace = tokEnd.Pos

p.SymTab.Pops()
p.SymTab.SubLevel()

return block
}
func (p *Parser) parseFunc() *ast.FuncDecl {
tokFunc := p.MustAcceptToken(token.FUNC)
tokFuncIdent := p.MustAcceptToken(token.IDENT)

p.MustAcceptToken(token.LPAREN)
funcFieldList := &ast.FieldList{}

p.SymTab.AddLevel()

......

p.SymTab.Pops()
p.SymTab.SubLevel()

return &ast.FuncDecl{
FuncPos: tokFunc.Pos,
Params: funcFieldList,
Name: funcName,
Type: funcType,
Body: funcBody,
}
}

Handling return statements

The AST node structure for a return statement is:

// A return statement node
type RetStmt struct {
Ret int // position of the return keyword
Expr Expr // return value expression
}

The core function parseStmt_return is:

func (p *Parser) parseStmt_return() *ast.RetStmt {
tokRet := p.MustAcceptToken(token.RETURN)
tokExp := p.parseExpr()

myTypeX := reflect.TypeOf(tokExp)
if myTypeX.String() == "*ast.Strings" || myTypeX.String() == "*ast.Bool" {
p.errorf(tokExp.Pos(), "Non-numbers cannot participate in the calculation")
}

return &ast.RetStmt{
Ret: tokRet.Pos,
Expr: tokExp,
}
}

Handling function calls

The AST node structure for a function call is:

// A function call
type CallExpr struct {
FuncName *Ident // function name
Lparen int // position of '('
Args []Expr // argument list
Rparen int // position of ')'
}

Function calls occur inside expressions, so the function to modify is parseExpr_primary:

if lp, ok := p.AcceptToken(token.LPAREN); ok {
var rp token.Token
astc := &ast.CallExpr{
FuncName: &ast.Ident{
NamePos: tok.Pos,
Name: tok.Literal,
},
}
for {
if rp, ok = p.AcceptToken(token.RPAREN); ok {
break
}
arg := p.parseExpr()
astc.Args = append(astc.Args, arg)
p.AcceptToken(token.COMMA)
}
astc.Lparen = lp.Pos
astc.Rparen = rp.Pos
asti.Offset = astc

return asti
}

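Function table

The function table is the list of global functions (p.file.Funcs). It is used here mainly for the various function-related error checks

Error type: duplicate function definition
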
func (p *Parser) checkFuncName(name string) bool {
for _, fu := range p.file.Funcs {
if fu.Name.Name == name {
return true
}
}
return false
}
tokFunc := p.MustAcceptToken(token.FUNC)
tokFuncIdent := p.MustAcceptToken(token.IDENT)
if ok := p.checkFuncName(tokFuncIdent.Literal); ok {
p.errorf(tokFuncIdent.Pos, "duplicate function declaration: %s", tokFuncIdent.Literal)
}

Error type: undefined function

if lp, ok := p.AcceptToken(token.LPAREN); ok {
if ok = p.checkFuncName(tok.Literal); !ok {
p.errorf(lp.Pos, "undefined function: %s", tok.Literal)
}
var rp token.Token
astc := &ast.CallExpr{
FuncName: &ast.Ident{
NamePos: tok.Pos,
Name: tok.Literal,
},
}

Error type: argument mismatch

func (p *Parser) checkFuncArgs(name string, args []ast.Expr) bool {
for _, fu := range p.file.Funcs {
if fu.Name.Name != name {
continue
}

if len(args) != len(fu.Params.List) {
return false
}

for i, arg := range fu.Params.List {
myType := reflect.TypeOf(args[i])
if arg.Type.Name == "string" && myType.String() != "*ast.Strings" {
return false
}
if arg.Type.Name == "bool" && myType.String() != "*ast.Bool" {
return false
}
}
break
}
return true
}
for {
if rp, ok = p.AcceptToken(token.RPAREN); ok {
break
}
arg := p.parseExpr()
astc.Args = append(astc.Args, arg)
p.AcceptToken(token.COMMA)
}
if ok = p.checkFuncArgs(tok.Literal, astc.Args); !ok {
p.errorf(lp.Pos, "The function parameters do not match")
}
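The reflect.TypeOf(...).String() comparisons work, but a type assertion performs the same check without reflection. A misspelled type name then fails to compile instead of silently never matching.

Here is a standalone sketch with stand-in node types; Strings, Bool and Number imitate the ast nodes used above, and argMatches and checkArgs are hypothetical. Like checkFuncArgs, it only enforces string and bool parameters. It returns an error describing the first mismatch instead of a bare bool.

```go
package main

import "fmt"

// Stand-ins for the parser's AST node types.
type Expr interface{}
type Strings struct{ Value string }
type Bool struct{ Value bool }
type Number struct{ Value int }

type Param struct{ Name, Type string }

// argMatches mirrors the literal checks above: string and bool parameters
// require the matching literal node; other parameter types are not checked.
func argMatches(typ string, arg Expr) bool {
	switch typ {
	case "string":
		_, ok := arg.(*Strings)
		return ok
	case "bool":
		_, ok := arg.(*Bool)
		return ok
	default:
		return true
	}
}

// checkArgs verifies the arity first, then each argument in order.
func checkArgs(params []Param, args []Expr) error {
	if len(args) != len(params) {
		return fmt.Errorf("expected %d arguments, got %d", len(params), len(args))
	}
	for i, p := range params {
		if !argMatches(p.Type, args[i]) {
			return fmt.Errorf("argument %d (%s): want %s, got %T", i+1, p.Name, p.Type, args[i])
		}
	}
	return nil
}

func main() {
	params := []Param{{"s", "string"}, {"n", "int"}}
	fmt.Println(checkArgs(params, []Expr{&Strings{"hi"}, &Number{1}})) // <nil>
	fmt.Println(checkArgs(params, []Expr{&Bool{true}, &Number{1}}))
}
```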

Symbol table

This experiment organizes the symbol table with the following structure:

type SymTab struct {
elements []*TypeSpec
top int
}
  • PS: for now the symbol table is only used for variables; function symbols will be added once the function module is completed in a later experiment

The accompanying functions are:

func NewSymTab() *SymTab {
return &SymTab{
elements: nil,
top: 0,
}
}

func (s *SymTab) Push(name, typ *Ident) bool {
for i := s.top - 1; i >= 0; i-- {
if s.elements[i].Name.Name == name.Name {
return false
}
}
s.elements = append(s.elements, &TypeSpec{
Name: name,
Type: typ,
})
s.top++
return true
}

func (s *SymTab) Pop() bool {
if s.top == 0 {
return false
}
s.top--
s.elements = s.elements[:len(s.elements)-1]
return true
}

func (s *SymTab) GetType(value *Ident) (*Ident, bool) {
for i := s.top - 1; i >= 0; i-- {
if s.elements[i].Name.Name == value.Name {
return s.elements[i].Type, true
}
}
return nil, false
}

func (s *SymTab) CheckType(value, typ *Ident) bool {
if tmp, ok := s.GetType(value); ok {
return tmp.Name == typ.Name
}
return false
}

func (s *SymTab) ShowAll() {
fmt.Println("----------SymTab-------------------------")
for i, e := range s.elements {
fmt.Printf("|@%d: {Name: %s, Type: %s}\t\t|\n", i+1, e.Name.Name, e.Type.Name)
}
fmt.Println("-----------------------------------------")
fmt.Println("")
}

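Error detection

The checks below do not yet cover undefined functions or mismatched function arguments; those will be added in a later experiment

Error type: duplicate definition

This error occurs while defining variables, in parseStmt_var, so code related to the symbol table must be added to that function:
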
if ok := p.SymTab.Push(&ast.Ident{
NamePos: tokIdent.Pos,
Name: tokIdent.Literal,
}, varSpec.Type); !ok {
p.errorf(tokIdent.Pos, "duplicate variable declaration: %s", tokIdent.Literal)
}

Error type: used before definition

Variables are usually used in expressions, so parseExpr_primary must be modified:

case token.IDENT: // variable, array element or struct field
p.MustAcceptToken(token.IDENT)
asti := &ast.IdentAS{
NamePos: tok.Pos,
Name: tok.Literal,
}
if _, ok := p.AcceptToken(token.LSB); ok {
asti.Offset = p.parseExpr()
p.AcceptToken(token.RSB)
}
if _, ok := p.AcceptToken(token.DOT); ok {
asti.Offset = p.parseExpr()
}

if _, ok := p.SymTab.GetType(&ast.Ident{NamePos: tok.Pos, Name: tok.Literal}); !ok {
p.errorf(tok.Pos, "unknown variable: %s", tok.Literal)
}
return asti

Error type: unknown type

Before type checking, a type table has to be created to record the existing types:

type TypeTab struct {
elements []string
top int
}

Its functions can mirror those of SymTab:

func NewTypeTab() *TypeTab {
var re = &TypeTab{
elements: make([]string, 4),
top: 4,
}

re.elements[0] = "bool"
re.elements[1] = "char"
re.elements[2] = "int"
re.elements[3] = "string"

return re
}

func (s *TypeTab) Push(typ string) bool {
for i := s.top - 1; i >= 0; i-- {
if s.elements[i] == typ {
return false
}
}
s.elements = append(s.elements, typ)
s.top++
return true
}

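func (s *TypeTab) Pop() bool {
if s.top == 0 {
return false
}
s.top--
// also drop the slice element; otherwise the next Push appends after a
// stale entry and top no longer matches len(elements)
s.elements = s.elements[:len(s.elements)-1]
return true
}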

func (s *TypeTab) CheckType(typ string) bool {
for i := s.top - 1; i >= 0; i-- {
if s.elements[i] == typ {
return true
}
}
return false
}

func (s *TypeTab) ShowAll() {
fmt.Println("----------TypeTab------------------------")
for i, e := range s.elements {
fmt.Printf("|@%d: {Type: %s}\t\t\t|\n", i+1, e)
}
fmt.Println("-----------------------------------------")
fmt.Println("")
}

Type checking happens while a variable is declared, i.e. in parseStmt_var, so add the check there:

if typ, ok := p.AcceptToken(token.IDENT); ok {
varSpec.Type = &ast.Ident{
NamePos: typ.Pos,
Name: typ.Literal,
}
if ok = p.TypeTab.CheckType(typ.Literal); !ok {
p.errorf(typ.Pos, "unknown type: %s", typ.Literal)
}
}

Error type: duplicate type definition

This error arises in a type statement, so the check is added to parseStmt_type:

if ok := p.TypeTab.Push(tokIdent.Literal); !ok {
p.errorf(tokIdent.Pos, "duplicate type name: %s", tokIdent.Literal)
}
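Both type checks (unknown type and duplicate type name) only need set membership. The sketch below is a map-based alternative to the slice in TypeTab. It is not the author's implementation, and it assumes type names are declared only at file scope, which matches parseFile handling token.TYPE. Under that assumption no Pop is needed.

```go
package main

import "fmt"

// TypeTab here is a set of known type names, seeded with the four built-ins.
type TypeTab struct{ known map[string]bool }

// NewTypeTab returns a table that already knows bool, char, int and string.
func NewTypeTab() *TypeTab {
	t := &TypeTab{known: map[string]bool{}}
	for _, name := range []string{"bool", "char", "int", "string"} {
		t.known[name] = true
	}
	return t
}

// Push registers a type name from a type statement; false if it already exists.
func (t *TypeTab) Push(name string) bool {
	if t.known[name] {
		return false
	}
	t.known[name] = true
	return true
}

// CheckType reports whether name is a known type.
func (t *TypeTab) CheckType(name string) bool { return t.known[name] }

func main() {
	tt := NewTypeTab()
	fmt.Println(tt.CheckType("int"))   // true
	fmt.Println(tt.CheckType("Point")) // false: unknown type
	fmt.Println(tt.Push("Point"))      // true
	fmt.Println(tt.Push("Point"))      // false: duplicate type name
	fmt.Println(tt.CheckType("Point")) // true
}
```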

Error type: type mismatch

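The compiler currently supports four types (bool, char, int, string), and mismatches have to be checked in three places: variable initialisation, assignment, and binary expressions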

The type-matching function:

// MatchType reports whether the literal expression right can be assigned to a
// variable of type left. Non-literal expressions are not checked and pass.
func (s *TypeTab) MatchType(left string, right interface{}) bool {
	// a type switch avoids comparing reflect type names as strings and,
	// unlike reflect.TypeOf(right).String(), does not panic when right is nil
	switch right.(type) {
	case *ast.Strings:
		return left == "string"
	case *ast.Number:
		return left == "int"
	case *ast.Bool:
		return left == "bool"
	case *ast.Char:
		return left == "char"
	}
	return true
}

Modify parseStmt_var to check that the initial value matches the declared type:

if _, ok := p.AcceptToken(token.ASSIGN); ok {
	varSpec.Value = p.parseExpr()

	// the type may be omitted (var x = 1), in which case there is nothing to compare
	if varSpec.Type != nil && !p.TypeTab.MatchType(varSpec.Type.Name, varSpec.Value) {
		p.errorf(tokIdent.Pos, "type mismatch: variable %s is declared as %s", varSpec.Name.Name, varSpec.Type.Name)
	}
}

Modify parseStmt_block to check the types in assignment statements:

case token.DEFINE, token.ASSIGN:
	p.ReadToken()
	exprValueList := p.parseExprList()
	if len(exprList) != len(exprValueList) {
		p.errorf(tok.Pos, "unknown token: %v", tok)
	}
	var assignStmt = &ast.AssignStmt{
		Target: make([]*ast.IdentAS, len(exprList)),
		OpPos:  tok.Pos,
		Op:     tok.Type,
		Value:  make([]ast.Expr, len(exprList)),
	}
	for i, target := range exprList {
		assignStmt.Target[i] = target.(*ast.IdentAS)
		assignStmt.Value[i] = exprValueList[i]

		// MatchType expects a type name, so look up the target's declared type
		// first; passing the variable name would flag every literal as a mismatch
		name := assignStmt.Target[i].Name
		if typ, ok := p.SymTab.GetType(&ast.Ident{Name: name}); ok && typ != nil {
			if !p.TypeTab.MatchType(typ.Name, assignStmt.Value[i]) {
				p.errorf(tok.Pos, "type mismatch: variable %s is declared as %s", name, typ.Name)
			}
		}
	}
	block.List = append(block.List, assignStmt)
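The key point of the assignment check is the order of operations: resolve the target's declared type, then compare it with the value. The sketch below shows this standalone. It uses hypothetical `Strings`/`Number` nodes and a plain `map[string]string` in place of SymTab, and is an illustration rather than the compiler's code.

```go
package main

import "fmt"

// Hypothetical literal nodes standing in for ast.Strings and ast.Number.
type Strings struct{ Value string }
type Number struct{ Value int }

// matchType mirrors MatchType for the two literal kinds used here;
// any other expression is not checked.
func matchType(typ string, v interface{}) bool {
	switch v.(type) {
	case *Strings:
		return typ == "string"
	case *Number:
		return typ == "int"
	}
	return true
}

// checkAssign resolves each target's declared type before comparing it
// with the value; passing the variable name to matchType instead would
// make every literal look mismatched.
func checkAssign(decls map[string]string, targets []string, values []interface{}) []string {
	if len(targets) != len(values) {
		return []string{"assignment count mismatch"}
	}
	var errs []string
	for i, name := range targets {
		typ, ok := decls[name]
		if !ok {
			errs = append(errs, "unknown variable: "+name)
			continue
		}
		if !matchType(typ, values[i]) {
			errs = append(errs, fmt.Sprintf("type mismatch: %s is %s", name, typ))
		}
	}
	return errs
}

func main() {
	decls := map[string]string{"x": "int", "s": "string"} // var x int; var s string
	// x, s = 1, "ok" passes; x, s = "a", 2 fails for both targets
	fmt.Println(checkAssign(decls, []string{"x", "s"}, []interface{}{&Number{1}, &Strings{"ok"}}))
	fmt.Println(checkAssign(decls, []string{"x", "s"}, []interface{}{&Strings{"a"}, &Number{2}}))
}
```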

Modify parseExpr_binary so that strings and booleans cannot take part in arithmetic:

op := p.PeekToken()
if op.Type.Precedence() < prec {
	return x
}

// the left operand may not be a string or bool literal
// (a type switch also avoids the panic reflect.TypeOf(nil).String() would cause)
switch x.(type) {
case *ast.Strings, *ast.Bool:
	p.errorf(op.Pos, "Non-numbers cannot participate in the calculation")
}

p.MustAcceptToken(op.Type)
y := p.parseExpr_binary(op.Type.Precedence() + 1)

// nor may the right operand
switch y.(type) {
case *ast.Strings, *ast.Bool:
	p.errorf(op.Pos, "Non-numbers cannot participate in the calculation")
}
x = &ast.BinaryExpr{OpPos: op.Pos, Op: op.Type, X: x, Y: y}
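The parser applies this rule while it builds the tree. The same rule can also run as a separate pass over a finished tree, as in the sketch below. The node types here are hypothetical stand-ins for the ast package. Like the parser check, the sketch ignores the operator, so comparisons such as "a" == "b" are rejected too; a fuller checker would distinguish by Op.

```go
package main

import "fmt"

// Hypothetical AST nodes for the sketch.
type Expr interface{}
type Number struct{ Value int }
type Strings struct{ Value string }
type Bool struct{ Value bool }
type BinaryExpr struct {
	Op   string
	X, Y Expr
}

// checkOperands walks a binary expression tree and rejects string and
// bool literal operands at any depth.
func checkOperands(e *BinaryExpr) error {
	for _, operand := range []Expr{e.X, e.Y} {
		switch n := operand.(type) {
		case *BinaryExpr:
			if err := checkOperands(n); err != nil {
				return err
			}
		case *Strings, *Bool:
			return fmt.Errorf("non-numbers cannot participate in the calculation: %T", n)
		}
	}
	return nil
}

func main() {
	good := &BinaryExpr{Op: "+", X: &Number{1}, Y: &BinaryExpr{Op: "*", X: &Number{2}, Y: &Number{3}}}
	bad := &BinaryExpr{Op: "+", X: &Number{1}, Y: &Strings{"a"}}
	fmt.Println(checkOperands(good)) // <nil>
	fmt.Println(checkOperands(bad))  // error mentioning *main.Strings
}
```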

First, add the corresponding definitions to token.go:

const (
......
STRINGS
ARRAY
TYPE
STRUCT
......
)
var tokens = [...]string{
......
STRINGS: "STRINGS",
ARRAY: "ARRAY",
TYPE: "type",
STRUCT: "struct",
......
}

String handling

In the lexer, all that is needed is handling for the double quote:

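case r == '"': // string literal: read up to the closing quote
	for {
		r := p.src.Read()
		if r == '"' {
			p.emit(token.STRINGS)
			break
		}
		// stop at end of input instead of looping forever on an unterminated
		// string (assumes Read reports end of input the same way the main
		// loop of run() tests for it)
		if r == rune(token.EOF) {
			p.emit(token.ERROR)
			return
		}
	}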
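The scanning rule (read until the closing quote, and fail cleanly if the input runs out) can be checked in isolation. The sketch below works on a rune slice rather than the lexer's source stream. `scanString` is a hypothetical helper, and, like the lexer case, it does not handle escape sequences.

```go
package main

import (
	"errors"
	"fmt"
)

// scanString expects src[start] == '"' and returns the literal including
// both quotes, or an error if the input ends before the closing quote.
func scanString(src []rune, start int) (string, error) {
	for i := start + 1; i < len(src); i++ {
		if src[i] == '"' {
			return string(src[start : i+1]), nil
		}
	}
	return "", errors.New("unterminated string literal")
}

func main() {
	lit, err := scanString([]rune(`x := "hello"`), 5)
	fmt.Println(lit, err) // "hello" <nil>
	_, err = scanString([]rune(`"oops`), 0)
	fmt.Println(err) // unterminated string literal
}
```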

In the parser, the AST node for strings can be modelled on Number:

// a string literal
type Strings struct {
	ValuePos int    // start position of the string
	ValueEnd int    // end position of the string
	Value    string // the string contents
	ValueLen int    // length of the string
}

Strings usually appear as expressions, so parseExpr_primary needs a case for them:

case token.STRINGS: // string literal
	tokStrings := p.MustAcceptToken(token.STRINGS)
	value := tokStrings.Literal[1 : len(tokStrings.Literal)-1] // strip the quotes
	return &ast.Strings{
		ValuePos: tokStrings.Pos + 1,
		ValueEnd: tokStrings.Pos + len(tokStrings.Literal) - 1,
		Value:    value,
		ValueLen: len(value),
	}

Array handling

Arrays involve two parts, declaration and use. First, define tokens for "[" and "]":

const (
......
LSB // [
RSB // ]
......
)
var tokens = [...]string{
......
LSB: "[",
RSB: "]",
......
}

Then teach the lexer to recognise "[" and "]":

case r == '[':
p.emit(token.LSB)
case r == ']':
p.emit(token.RSB)

The AST node for variables also has to change:

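// variable declaration
type VarSpec struct {
	VarPos int       // position of the var keyword
	Name   *IdentArr // variable name
	Type   *Ident    // variable type, may be omitted
	Value  Expr      // initial value expression
}

// a variable or an array element
type IdentArr struct {
	NamePos int    // position of the identifier
	Name    string // identifier
	Offset  Expr   // array index
}
  • This is simply the variable node plus an Expr that records the array index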

An array is declared like this:

var arrName [arrLen]arrType

Compared with an ordinary variable declaration there is an extra [arrLen]arrType, so parseStmt_var is modified directly:

func (p *Parser) parseStmt_var() *ast.VarSpec {
	tokVar := p.MustAcceptToken(token.VAR)
	tokIdent := p.MustAcceptToken(token.IDENT)

	var varSpec = &ast.VarSpec{
		VarPos: tokVar.Pos,
	}

	varSpec.Name = &ast.IdentArr{
		NamePos: tokIdent.Pos,
		Name:    tokIdent.Literal,
		Offset:  nil,
	}

	if typ, ok := p.AcceptToken(token.IDENT); ok { // var name type
		varSpec.Type = &ast.Ident{
			NamePos: typ.Pos,
			Name:    typ.Literal,
		}
	}

	if _, ok := p.AcceptToken(token.LSB); ok { // '[' found: record the array length in Offset
		varSpec.Name.Offset = p.parseExpr()
		p.MustAcceptToken(token.RSB)

		// the element type is required after [len]
		typ := p.MustAcceptToken(token.IDENT)
		varSpec.Type = &ast.Ident{
			NamePos: typ.Pos,
			Name:    typ.Literal,
		}
	}

	if _, ok := p.AcceptToken(token.ASSIGN); ok {
		varSpec.Value = p.parseExpr()
	}

	p.AcceptToken(token.SEMICOLON)
	return varSpec
}
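The two declaration shapes handled above can be exercised without the real token stream. The sketch below parses a pre-split token slice into a hypothetical `VarDecl`. Unlike parseStmt_var, which reads the length with parseExpr, it accepts only an integer literal for the length, and it requires the element type after `]`.

```go
package main

import (
	"fmt"
	"strconv"
)

// VarDecl is a hypothetical, simplified result of parsing a var statement.
type VarDecl struct {
	Name   string
	Type   string
	ArrLen int // -1 when the variable is not an array
}

// parseVar handles "var name type" and "var name [len]type".
func parseVar(toks []string) (VarDecl, error) {
	if len(toks) < 3 || toks[0] != "var" {
		return VarDecl{}, fmt.Errorf("expected: var name type")
	}
	d := VarDecl{Name: toks[1], ArrLen: -1}
	rest := toks[2:]
	if rest[0] == "[" {
		if len(rest) != 4 || rest[2] != "]" {
			return VarDecl{}, fmt.Errorf("expected: [len]type")
		}
		n, err := strconv.Atoi(rest[1])
		if err != nil {
			return VarDecl{}, fmt.Errorf("bad array length %q", rest[1])
		}
		d.ArrLen, rest = n, rest[3:]
	}
	if len(rest) != 1 {
		return VarDecl{}, fmt.Errorf("expected a single type name")
	}
	d.Type = rest[0]
	return d, nil
}

func main() {
	fmt.Println(parseVar([]string{"var", "a", "int"}))
	fmt.Println(parseVar([]string{"var", "arr", "[", "10", "]", "int"}))
}
```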

Arrays are mostly used in expressions, so parseExpr_primary needs to handle them:

case token.IDENT: // variable or array element
	p.MustAcceptToken(token.IDENT)
	asti := &ast.IdentArr{
		NamePos: tok.Pos,
		Name:    tok.Literal,
	}
	if _, ok := p.AcceptToken(token.LSB); ok { // a '[' means an array element
		asti.Offset = p.parseExpr()
		p.MustAcceptToken(token.RSB)
	}
	return asti

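Finally, to keep assignment statements consistent, the corresponding part of parseStmt and parseStmt_block must be changed as well; the Target field of ast.AssignStmt becomes []*ast.IdentArr: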

var assignStmt = &ast.AssignStmt{
Target: make([]*ast.IdentArr, len(exprList)),
OpPos: tok.Pos,
Op: tok.Type,
Value: make([]ast.Expr, len(exprList)),
}
for i, target := range exprList {
assignStmt.Target[i] = target.(*ast.IdentArr)
assignStmt.Value[i] = exprValueList[i]
}

Struct handling

In the lexer, the two keywords "type" and "struct" must be recognised:

var tokens = [...]string{
.....
TYPE: "type",
STRUCT: "struct",
.....
}

They also need to be added to the keyword table:

var keywords = map[string]TokenType{
......
"type": TYPE,
"struct": STRUCT,
......
}

The AST nodes for structs:

// type definition
type TypeSpec struct {
	Name *Ident // type name
	Type *Ident // type
}

// struct definition
type StructType struct {
	TypePos int         // position of the type keyword
	Name    *Ident      // struct name
	Types   []*TypeSpec // list of types (fields)
}

Struct definitions appear at file scope, so parseFile has to be changed first:

func (p *Parser) parseFile() {
p.file = &ast.File{
Filename: p.Filename(),
Source: p.Source(),
}

p.file.Pkg = p.parsePackage()

for {
switch tok := p.PeekToken(); tok.Type {
case token.EOF:
return
case token.ERROR:
panic(tok)
case token.SEMICOLON:
p.AcceptTokenList(token.SEMICOLON)
case token.FUNC:
p.file.Funcs = append(p.file.Funcs, p.parseFunc())
case token.VAR: /* global variable */
p.file.Globals = append(p.file.Globals, p.parseStmt_var())
case token.TYPE: /* type / struct definition */
p.file.Types = append(p.file.Types, p.parseStmt_type())
default:
p.errorf(tok.Pos, "unknown token: %v", tok)
}
}
}

The core function, parseStmt_type:

func (p *Parser) parseStmt_type() *ast.StructType {
	tokpos := p.MustAcceptToken(token.TYPE)
	tokIdent := p.MustAcceptToken(token.IDENT)

	var StructType = &ast.StructType{
		TypePos: tokpos.Pos,
	}

	StructType.Name = &ast.Ident{
		NamePos: tokIdent.Pos,
		Name:    tokIdent.Literal,
	}

	switch tok := p.PeekToken(); tok.Type {
	case token.IDENT: // type Name OtherType
		tokty := p.MustAcceptToken(token.IDENT)
		StructType.Types = append(StructType.Types, &ast.TypeSpec{
			Name: nil,
			Type: &ast.Ident{
				NamePos: tokty.Pos,
				Name:    tokty.Literal,
			},
		})
	case token.STRUCT: // type Name struct { field Type; ... }
		p.MustAcceptToken(token.STRUCT)
		p.MustAcceptToken(token.LBRACE)

		for {
			p.AcceptTokenList(token.SEMICOLON) // skip separators between fields
			if _, ok := p.AcceptToken(token.RBRACE); ok {
				break
			}
			next := p.PeekToken()
			if next.Type == token.EOF { // avoid looping forever on a missing '}'
				p.errorf(next.Pos, "missing '}' in struct %s", tokIdent.Literal)
				break
			}
			// each field is exactly two identifiers: name and type
			toks, ok := p.AcceptTokenList(token.IDENT)
			if !ok || len(toks) != 2 {
				p.errorf(next.Pos, "struct field expects: name type")
				break
			}
			StructType.Types = append(StructType.Types, &ast.TypeSpec{
				Name: &ast.Ident{
					NamePos: toks[0].Pos,
					Name:    toks[0].Literal,
				},
				Type: &ast.Ident{
					NamePos: toks[1].Pos,
					Name:    toks[1].Literal,
				},
			})
		}
	default:
		p.errorf(tok.Pos, "unknown tok: type=%v, lit=%q", tok.Type, tok.Literal)
	}

	p.AcceptToken(token.SEMICOLON)
	return StructType
}
  • This implements both forms of type: "type Name OtherType" and "type Name struct { ... }"
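The two forms of type can be tried on a pre-split token slice. The sketch below uses a hypothetical `Field` in place of ast.TypeSpec and returns plain errors where the parser calls p.errorf. Its point is to show the loop structure: skip separators, stop at `}`, and require a name and a type for every field.

```go
package main

import "fmt"

// Field is a hypothetical (name, type) pair, playing the role of ast.TypeSpec.
type Field struct{ Name, Type string }

// parseType accepts "type T U" and "type T struct { a int ; b string }".
// For the first form the single Field has an empty Name, matching the
// Name: nil case in parseStmt_type.
func parseType(toks []string) (string, []Field, error) {
	if len(toks) < 3 || toks[0] != "type" {
		return "", nil, fmt.Errorf("expected: type Name ...")
	}
	name := toks[1]
	if toks[2] != "struct" {
		return name, []Field{{Type: toks[2]}}, nil
	}
	if len(toks) < 4 || toks[3] != "{" {
		return "", nil, fmt.Errorf("expected '{' after struct")
	}
	var fields []Field
	for i := 4; i < len(toks); {
		switch {
		case toks[i] == "}":
			return name, fields, nil
		case toks[i] == ";":
			i++ // separators between fields
		case i+1 < len(toks) && toks[i+1] != "}" && toks[i+1] != ";":
			fields = append(fields, Field{Name: toks[i], Type: toks[i+1]})
			i += 2
		default:
			return "", nil, fmt.Errorf("field %q has no type", toks[i])
		}
	}
	return "", nil, fmt.Errorf("missing '}' in struct %s", name)
}

func main() {
	fmt.Println(parseType([]string{"type", "MyInt", "int"}))
	fmt.Println(parseType([]string{"type", "P", "struct", "{", "x", "int", ";", "y", "int", "}"}))
}
```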

Structs are mostly used in expressions, so parseExpr_primary needs to handle them:

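case token.IDENT: // variable, array element, or struct field
	p.MustAcceptToken(token.IDENT)
	asti := &ast.IdentArr{
		NamePos: tok.Pos,
		Name:    tok.Literal,
	}
	if _, ok := p.AcceptToken(token.LSB); ok { // a[expr]
		asti.Offset = p.parseExpr()
		p.MustAcceptToken(token.RSB)
	}
	if _, ok := p.AcceptToken(token.DOT); ok { // s.field
		// accept only the field name: parseExpr here would also swallow the
		// rest of the expression, e.g. the "+ 1" in s.x + 1
		field := p.MustAcceptToken(token.IDENT)
		asti.Offset = &ast.Ident{NamePos: field.Pos, Name: field.Literal}
	}
	return asti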

Variables and scope

The minimal µGo compiler does not handle the "=" sign, so handling for it must be added to (p *Lexer) run() (tokens []token.Token):

case r == '=': // =, ==
switch p.src.Read() {
case '=':
p.emit(token.EQL)
default:
p.src.Unread()
p.emit(token.ASSIGN)
}
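The case above is a one-rune lookahead: read the next rune, and if it is not "=", push it back with Unread. The sketch below shows the same decision on a rune slice. `lexAssign` is a hypothetical helper, and the returned index plays the role of the Read/Unread pair.

```go
package main

import "fmt"

// lexAssign scans the '=' at src[i] with one rune of lookahead and returns
// the token kind plus the index of the next unread rune.
func lexAssign(src []rune, i int) (string, int) {
	if i+1 < len(src) && src[i+1] == '=' {
		return "EQL", i + 2 // "=="
	}
	return "ASSIGN", i + 1 // lone "=": the lookahead rune stays unread
}

func main() {
	src := []rune("a == b = c")
	for i := 0; i < len(src); {
		if src[i] == '=' {
			tok, next := lexAssign(src, i)
			fmt.Println(tok) // EQL, then ASSIGN
			i = next
			continue
		}
		i++
	}
}
```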

This lab implements three parts of the parser: variable declarations, nested blocks, and assignment statements

Statements are handled in the parser as follows:

func (p *Parser) parseStmt() ast.Stmt {
switch tok := p.PeekToken(); tok.Type {
case token.EOF:
return nil
case token.ERROR:
p.errorf(tok.Pos, "invalid token: %s", tok.Literal)
case token.SEMICOLON:
p.AcceptTokenList(token.SEMICOLON)
return nil
case token.LBRACE: // 语句块
return p.parseStmt_block()
default:
exprList := p.parseExprList() // 表达式列表 exprList
switch tok := p.PeekToken(); tok.Type {
case token.SEMICOLON, token.LBRACE: // exprList;
if len(exprList) != 1 {
p.errorf(tok.Pos, "unknown token: %v", tok.Type)
}
return &ast.ExprStmt{
X: exprList[0],
}
case token.DEFINE, token.ASSIGN:
// exprList := exprList; && exprList = exprList;
p.ReadToken()
exprValueList := p.parseExprList()
if len(exprList) != len(exprValueList) {
p.errorf(tok.Pos, "unknown token: %v", tok)
}
var assignStmt = &ast.AssignStmt{
Target: make([]*ast.Ident, len(exprList)),
OpPos: tok.Pos,
Op: tok.Type,
Value: make([]ast.Expr, len(exprList)),
}
for i, target := range exprList {
assignStmt.Target[i] = target.(*ast.Ident)
assignStmt.Value[i] = exprValueList[i]
}
return assignStmt
default:
p.errorf(tok.Pos, "unknown token: %v", tok)
}
}

panic("unreachable")
}
  • The default branch covers three cases:
exprList ;
exprList := exprList;
exprList = exprList;
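  • The first case yields an ExprStmt; the other two yield an assignment node, whose AST structure is:
// an assignment statement node
type AssignStmt struct {
	Target []*Ident         // targets being assigned
	OpPos  int              // position of ':=' or '='
	Op     token.TokenType  // '=' or ':='
	Value  []Expr           // values
}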
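The default branch decides between the three forms by looking at the first token after the expression list. The sketch below replays that decision, including the count check for assignments. It is a simplification: it treats every expression as a single token, and `splitList` and `classify` are hypothetical helpers rather than parser methods.

```go
package main

import "fmt"

// splitList reads "e1 , e2 , ..." from position i, treating every expression
// as one token. It returns the number of expressions and the next index.
func splitList(toks []string, i int) (int, int) {
	n := 0
	for i < len(toks) {
		n++
		i++ // consume one expression
		if i < len(toks) && toks[i] == "," {
			i++
			continue
		}
		break
	}
	return n, i
}

// classify picks one of the three statement forms by the token that
// follows the first expression list.
func classify(toks []string) (string, error) {
	lhs, i := splitList(toks, 0)
	if i >= len(toks) {
		return "", fmt.Errorf("unexpected end of statement")
	}
	switch op := toks[i]; op {
	case ";":
		if lhs != 1 {
			return "", fmt.Errorf("expression statement needs exactly one expression")
		}
		return "ExprStmt", nil
	case ":=", "=":
		rhs, _ := splitList(toks, i+1)
		if rhs != lhs {
			return "", fmt.Errorf("%d targets but %d values", lhs, rhs)
		}
		return "AssignStmt " + op, nil
	default:
		return "", fmt.Errorf("unknown token: %s", op)
	}
}

func main() {
	for _, stmt := range [][]string{
		{"f", ";"},
		{"a", ",", "b", ":=", "1", ",", "2", ";"},
		{"a", "=", "1", ",", "2", ";"}, // count mismatch
	} {
		fmt.Println(classify(stmt))
	}
}
```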

The AST structure for a block statement:

// a block statement
type BlockStmt struct {
	Lbrace int // position of '{'
	List   []Stmt
	Rbrace int // position of '}'
}

The core function, parseStmt_block, is implemented as follows:

func (p *Parser) parseStmt_block() *ast.BlockStmt {
	block := &ast.BlockStmt{}
	tokBegin := p.MustAcceptToken(token.LBRACE) // position of '{'

Loop:
	for {
		switch tok := p.PeekToken(); tok.Type {
		case token.EOF:
			break Loop
		case token.ERROR:
			p.errorf(tok.Pos, "invalid token: %s", tok.Literal)
		case token.SEMICOLON:
			p.AcceptTokenList(token.SEMICOLON)
		case token.RBRACE: // }
			break Loop
		default:
			block.List = append(block.List, p.parseStmt_expr()) // expressions may contain nested blocks
		}
	}

	tokEnd := p.MustAcceptToken(token.RBRACE) // position of '}'
	block.Lbrace = tokBegin.Pos
	block.Rbrace = tokEnd.Pos
	return block
}
  • A labelled break (break Loop) leaves both the switch and the enclosing for in one step
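A plain break inside a switch case leaves only the switch in Go. The loop keeps running, which is why parseStmt_block needs the label. The small example below shows the difference with a hypothetical `firstNegative` function.

```go
package main

import "fmt"

// firstNegative stops at the first negative number. A plain "break" in the
// switch would only end the switch; "break Loop" ends the for loop too.
func firstNegative(xs []int) int {
	found := 0
Loop:
	for _, x := range xs {
		switch {
		case x < 0:
			found = x
			break Loop
		default:
			// keep scanning
		}
	}
	return found
}

func main() {
	fmt.Println(firstNegative([]int{3, 1, -4, -1})) // -4
}
```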

The AST node used for variables:

// variable information
type VarSpec struct {
	VarPos int    // position of the var keyword
	Name   *Ident // variable name
	Type   *Ident // variable type, may be omitted
	Value  Expr   // initial value expression
}

A variable can be global or declared inside a block, so parseStmt_var is added to both parseStmt and parseStmt_block

The core function, parseStmt_var:

func (p *Parser) parseStmt_var() *ast.VarSpec {
tokVar := p.MustAcceptToken(token.VAR)
tokIdent := p.MustAcceptToken(token.IDENT)

var varSpec = &ast.VarSpec{
VarPos: tokVar.Pos,
}

varSpec.Name = &ast.Ident{
NamePos: tokIdent.Pos,
Name: tokIdent.Literal,
}

if typ, ok := p.AcceptToken(token.IDENT); ok {
varSpec.Type = &ast.Ident{
NamePos: typ.Pos,
Name: typ.Literal,
}
}

if _, ok := p.AcceptToken(token.ASSIGN); ok {
varSpec.Value = p.parseExpr()
}
p.AcceptToken(token.SEMICOLON)
return varSpec
}
  • It uses the token API implemented earlier to collect the node's key information, then builds and returns the AST node

if branches and for loops

Add if and for to the token types and to the registered keywords so that the lexer can recognise them

The if and for statement nodes are defined as follows:

// an if statement node
type IfStmt struct {
	If   int        // position of the if keyword
	Init Stmt       // init statement
	Cond Expr       // condition, *BinaryExpr
	Body *BlockStmt // statements run when the condition is true
	Else Stmt       // else branch
}

// a for statement node
type ForStmt struct {
	For  int        // position of the for keyword
	Init Stmt       // init statement
	Cond Expr       // condition expression
	Post Stmt       // post (iteration) statement
	Body *BlockStmt // loop body
}

In the parser, if and for statements can only appear inside a block, so they are added to parseStmt_block:

case token.IF:
block.List = append(block.List, p.parseStmt_if())
case token.FOR:
block.List = append(block.List, p.parseStmt_for())

An if statement can take the following forms:

if x > 0 {}
if x := 1; x > 0 {}
if x > 0 {} else {}
if x := 1; x > 0 {} else {}

The corresponding parsing function:

func (p *Parser) parseStmt_if() *ast.IfStmt {
tokIf := p.MustAcceptToken(token.IF)
ifStmt := &ast.IfStmt{
If: tokIf.Pos,
}
stmt := p.parseStmt()
if _, ok := p.AcceptToken(token.SEMICOLON); ok {
ifStmt.Init = stmt
ifStmt.Cond = p.parseExpr()
ifStmt.Body = p.parseStmt_block()
} else {
ifStmt.Init = nil
if cond, ok := stmt.(*ast.ExprStmt); ok {
ifStmt.Cond = cond.X
} else {
p.errorf(tokIf.Pos, "if cond expect expr: %#v", stmt)
}
ifStmt.Body = p.parseStmt_block()
}

if _, ok := p.AcceptToken(token.ELSE); ok {
switch p.PeekToken().Type {
case token.IF: // else if
ifStmt.Else = p.parseStmt_if()
default:
ifStmt.Else = p.parseStmt_block()
}
}

return ifStmt
}

A for statement can take the following forms:

for {}
for x > 10 {}
for x := 0; x < 10; x = x+1 {}

The corresponding parsing function:

func (p *Parser) parseStmt_for() *ast.ForStmt {
tokFor := p.MustAcceptToken(token.FOR)

forStmt := &ast.ForStmt{
For: tokFor.Pos,
}

// for {}
if _, ok := p.AcceptToken(token.LBRACE); ok {
p.UnreadToken()
forStmt.Body = p.parseStmt_block()
return forStmt
}

// for Cond {}
// for Init?; Cond?; Post? {}

// for ; ...
if _, ok := p.AcceptToken(token.SEMICOLON); ok {
forStmt.Init = nil

// for ;; ...
if _, ok := p.AcceptToken(token.SEMICOLON); ok {
if _, ok := p.AcceptToken(token.LBRACE); ok {
// for ;; {}
p.UnreadToken()
forStmt.Body = p.parseStmt_block()
return forStmt
} else {
// for ; ; postStmt {}
forStmt.Post = p.parseStmt()
forStmt.Body = p.parseStmt_block()
return forStmt
}
} else {
// for ; cond ; ... {}
forStmt.Cond = p.parseExpr()
p.MustAcceptToken(token.SEMICOLON)
if _, ok := p.AcceptToken(token.LBRACE); ok {
// for ; cond ; {}
p.UnreadToken()
forStmt.Body = p.parseStmt_block()
return forStmt
} else {
// for ; cond ; postStmt {}
forStmt.Post = p.parseStmt()
forStmt.Body = p.parseStmt_block()
return forStmt
}
}
} else {
stmt := p.parseStmt()

if _, ok := p.AcceptToken(token.LBRACE); ok {
// for cond {}
p.UnreadToken()
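if exprStmt, ok := stmt.(*ast.ExprStmt); ok { // parseStmt wraps the condition in an ExprStmt
forStmt.Cond = exprStmt.X
} else {
p.errorf(tokFor.Pos, "for cond expect expr: %#v", stmt)
}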
forStmt.Body = p.parseStmt_block()
return forStmt
} else {
// for init;
p.MustAcceptToken(token.SEMICOLON)
forStmt.Init = stmt

// for ;; ...
if _, ok := p.AcceptToken(token.SEMICOLON); ok {
if _, ok := p.AcceptToken(token.LBRACE); ok {
// for ;; {}
p.UnreadToken()
forStmt.Body = p.parseStmt_block()
return forStmt
} else {
// for ; ; postStmt {}
forStmt.Post = p.parseStmt()
forStmt.Body = p.parseStmt_block()
return forStmt
}
} else {
// for ; cond ; ... {}
forStmt.Cond = p.parseExpr()
p.MustAcceptToken(token.SEMICOLON)
if _, ok := p.AcceptToken(token.LBRACE); ok {
// for ; cond ; {}
p.UnreadToken()
forStmt.Body = p.parseStmt_block()
return forStmt
} else {
// for ; cond ; postStmt {}
forStmt.Post = p.parseStmt()
forStmt.Body = p.parseStmt_block()
return forStmt
}
}
}
}
}
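Most of the branches in parseStmt_for differ only in which of the three clauses are present. An alternative design is to count the semicolons in the header first: zero means `for {}` or `for cond {}`, and two means the three-clause form. The sketch below does this on the header text as a string. It is only an illustration: `splitForHeader` is hypothetical, and it assumes no ";" appears inside nested braces or string literals. The real parser works on tokens.

```go
package main

import (
	"fmt"
	"strings"
)

// ForHeader holds the three optional clauses of a for statement.
type ForHeader struct{ Init, Cond, Post string }

// splitForHeader splits the text between "for" and "{" on ';'.
func splitForHeader(h string) (ForHeader, error) {
	parts := strings.Split(h, ";")
	for i := range parts {
		parts[i] = strings.TrimSpace(parts[i])
	}
	switch len(parts) {
	case 1: // for {}  or  for cond {}
		return ForHeader{Cond: parts[0]}, nil
	case 3: // for init; cond; post {}  (any clause may be empty)
		return ForHeader{Init: parts[0], Cond: parts[1], Post: parts[2]}, nil
	default:
		return ForHeader{}, fmt.Errorf("for header needs 0 or 2 semicolons, got %d", len(parts)-1)
	}
}

func main() {
	for _, h := range []string{"", "x > 10", "x := 0; x < 10; x = x+1", "a; b"} {
		fmt.Println(splitForHeader(h))
	}
}
```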

Assignment statements

Assignments appear inside blocks, so only parseStmt_block needs to change

The approach is the same as the handling of := and = (variable initialisation and assignment) in parseStmt:

default:
// exprList ;
// exprList := exprList;
// exprList = exprList;
exprList := p.parseExprList()
switch tok := p.PeekToken(); tok.Type {
case token.SEMICOLON:
if len(exprList) != 1 {
p.errorf(tok.Pos, "unknown token: %v", tok.Type)
}
block.List = append(block.List, &ast.ExprStmt{
X: exprList[0],
})
case token.DEFINE, token.ASSIGN:
p.ReadToken()
exprValueList := p.parseExprList()
if len(exprList) != len(exprValueList) {
p.errorf(tok.Pos, "unknown token: %v", tok)
}
var assignStmt = &ast.AssignStmt{
Target: make([]*ast.Ident, len(exprList)),
OpPos: tok.Pos,
Op: tok.Type,
Value: make([]ast.Expr, len(exprList)),
}
for i, target := range exprList {
assignStmt.Target[i] = target.(*ast.Ident)
assignStmt.Value[i] = exprValueList[i]
}
block.List = append(block.List, assignStmt)
default:
p.errorf(tok.Pos, "unknown token: %v", tok)
}
}