三维医学图像深度学习,数据增强方法(monai):RandHistogramShiftD, Flipd, Rotate90d

本文最后更新于:2023年4月7日 下午

#coding:utf-8
import torch
from monai.transforms import Compose, RandHistogramShiftD, Flipd, Rotate90d
import matplotlib.pyplot as plt
import SimpleITK as sitk
# Keys of the sample dict that the transform chain operates on.
KEYS = ("image", "label")


class aug:
    """Rotate/flip an image-label pair and randomly shift the image histogram.

    The axis arguments below are written for tensors laid out as
    (1, 1, D, H, W), which is how the ``__main__`` section of this script builds
    them. MONAI treats the first axis as the channel axis, so the remaining four
    axes (1, D, H, W) count as "spatial": ``spatial_axes=(2, 3)`` selects H and W,
    and ``spatial_axis=(1, 2, 3)`` selects D, H and W.
    NOTE(review): feeding (C, D, H, W) data instead requires shifting these axis
    indices down by one.
    """

    def __init__(self):
        self.random_rotated = Compose([
            # Spatial transforms: applied to both keys so image and label stay aligned.
            Rotate90d(KEYS, k=1, spatial_axes=(2, 3), allow_missing_keys=True),
            Flipd(KEYS, spatial_axis=(1, 2, 3), allow_missing_keys=True),
            # Intensity transform: image only. A histogram shift remaps voxel values,
            # which would alter the class ids of a label with intermediate values.
            RandHistogramShiftD("image", prob=1, num_control_points=30, allow_missing_keys=True),
        ])

    def forward(self, x):
        """Apply the transform chain to the sample dict ``x`` and return the result."""
        return self.random_rotated(x)

    # Allow ``aug()(sample)`` like any other MONAI/torchvision-style transform.
    __call__ = forward

def save(before_x, after_x, new_path, new_name=""):
    """Write one augmented volume to ``new_path`` using a reference file's geometry.

    Args:
        before_x: path of the original (un-augmented) file; only its header
            (origin, spacing, direction) is used.
        after_x: augmented data with leading batch and channel axes, e.g. the
            (1, 1, D, H, W) tensors built in ``__main__``; ``after_x[0, 0, ...]``
            is what gets written.
        new_path: destination file path.
        new_name: ``"image"`` writes 16-bit signed voxels; any other value (the
            script passes ``"label"``) writes 8-bit unsigned voxels.
    """
    import numpy as np

    # Only geometry is needed from the reference, so read its header, not its voxels.
    reader = sitk.ImageFileReader()
    reader.SetFileName(before_x)
    reader.ReadImageInformation()

    volume = after_x[0, 0, ...]
    if isinstance(volume, torch.Tensor):
        volume = volume.detach().cpu().numpy()
    volume = np.asarray(volume)
    if np.issubdtype(volume.dtype, np.floating):
        # Intensity augmentation can yield floats; round so the integer cast
        # below does not simply truncate toward zero.
        volume = np.rint(volume)

    # Store the output with the pixel type selected for this kind of volume.
    pixel_type = sitk.sitkInt16 if new_name == "image" else sitk.sitkUInt8
    out = sitk.Cast(sitk.GetImageFromArray(volume), pixel_type)

    # NOTE(review): geometry is copied verbatim from the un-augmented file. If the
    # volume was rotated in-plane, its x/y axes are swapped, so this is only exact
    # when x and y spacing are equal — confirm for anisotropic data.
    out.SetDirection(reader.GetDirection())
    out.SetOrigin(reader.GetOrigin())
    out.SetSpacing(reader.GetSpacing())

    sitk.WriteImage(out, new_path)


if __name__ == "__main__":
    image = r"D:\MyData\3Dircadb1_fusion_date\image_2.nii"  # original image
    label = r"D:\MyData\3Dircadb1_fusion_date\liver_2.nii"  # label
    # NOTE(review): the outputs use different indices (image_0 vs liver_1) although
    # they form one image/label pair — confirm the intended file names.
    new_path = r"D:\MyData\3Dircadb1_fusion_date\image_0.nii"  # augmented image
    new_path1 = r"D:\MyData\3Dircadb1_fusion_date\liver_1.nii"  # augmented label

    # GetArrayFromImage returns numpy arrays ordered (D, H, W).
    ct1 = sitk.GetArrayFromImage(sitk.ReadImage(image))
    seg1 = sitk.GetArrayFromImage(sitk.ReadImage(label))

    # Add batch and channel axes -> (1, 1, D, H, W), the layout `aug` is configured for.
    m = {
        "image": torch.from_numpy(ct1[None, None, ...]),
        "label": torch.from_numpy(seg1[None, None, ...]),
    }
    augs = aug()
    print(m["image"].shape)
    data_dict = augs.forward(m)

    save(image, data_dict["image"], new_path, "image")
    save(label, data_dict["label"], new_path1, "label")

    print(data_dict["image"].shape)

    # The Flipd step in `aug` reverses the depth axis, so augmented slice k
    # corresponds to original slice D-1-k; compare matching slices.
    depth = ct1.shape[0]
    aug_slice = min(66, depth - 1)
    orig_slice = depth - 1 - aug_slice

    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(ct1[orig_slice, ...], cmap="gray")
    axes[1].imshow(data_dict["image"][0, 0, aug_slice, ...], cmap="gray")
    axes[2].imshow(data_dict["label"][0, 0, aug_slice, ...])
    plt.show()

打赏支持
“如果你觉得我的文章不错,不妨鼓励我继续写作。”

三维医学图像深度学习,数据增强方法(monai):RandHistogramShiftD, Flipd, Rotate90d
https://dreamoneyou.github.io/2022/三维医学图像深度学习,数据增强方法(monai):RandHistogramShiftD, Flipd, Rotate90d/
作者
九叶草
发布于
2022年3月14日
更新于
2023年4月7日
许可协议