本笔记来源于B站Up主: 有Li 的影像组学系列教学视频
本节(30)主要介绍: SMOTE解决数据不平衡的问题
SMOTE基本介绍
SMOTE (Synthetic Minority Over-sampling Technique) 采样过程:
找到少数类中的某个样本Xi的k个近邻
随机选取一个Xi的近邻Xi(nn), 并随机生成一个介于0-1之间的数ζ1,合成的新样本Xnew = Xi + ζ1*Xi(nn)
对每一个随机选出的Xi的近邻重复步骤2
对少数类中的每个样本重复上述步骤
例:如果样本数据是一维的,一个样本是0,它有一个近邻是1,那么合成样本将是一个0-1之间的数。
代码实现
1、导入包
# pip install -U imbalanced-learn # for the first time
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import Counter
from imblearn.over_sampling import SMOTE
2、准备数据
X, y = make_classification(n_classes = 2, class_sep = 2,
weights = [0.9, 0.1], n_informative = 2,
n_redundant = 0, flip_y = 0,
n_features = 2, n_clusters_per_class = 1,
n_samples = 100, random_state =1)
print(Counter(y))
# > Counter({0: 90, 1: 10})
3、探索数据
plt.figure()
sns.scatterplot(X[:,0],X[:,1],hue = y)
plt.show()
Output:
4、制造数据
smo = SMOTE(random_state = 42)
X_smo, y_smo = smo.fit_sample(X, y)
print(Counter(y_smo))
# > Counter({0: 90, 1: 90})
5、显示制造的数据
plt.figure()
sns.scatterplot(X_smo[:,0],X_smo[:,1], hue = y_smo, palette = 'Accent')
sns.scatterplot(X[:,0],X[:,1],hue = y)
plt.show()
Output:
注意:这种方法只能用于训练集,不能用于测试集(test subset)。