1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
| import torch from monai.transforms import Compose, RandHistogramShiftD, Flipd, Rotate90d import matplotlib.pyplot as plt import SimpleITK as sitk
KEYS = ("image", "label")


class aug():
    """Augmentation pipeline for a {"image": ..., "label": ...} sample dict.

    Spatial transforms (90-degree rotation, flip) are applied to every key in
    ``KEYS`` so the image and its mask stay aligned. The random histogram
    shift is an intensity transform and is applied to the image only; remapping
    the intensities of a segmentation mask would corrupt its label values.
    """

    def __init__(self):
        # NOTE(review): the caller adds two leading singleton axes
        # (x[None, None, ...]). MONAI treats axis 0 as the channel axis, so the
        # "spatial" axes seen by these transforms are array axes 1..4, i.e.
        # spatial axis i == array axis i + 1. Confirm if the input layout changes.
        self.random_rotated = Compose([
            # Deterministic 90-degree rotation (k=1) in the plane of spatial axes 2 and 3.
            Rotate90d(KEYS, k=1, spatial_axes=(2, 3), allow_missing_keys=True),
            # Deterministic flip along spatial axes 1, 2 and 3.
            Flipd(KEYS, spatial_axis=(1, 2, 3), allow_missing_keys=True),
            # prob=1: always applied; only the control-point perturbation is random.
            # Restricted to the image: intensity augmentation must not alter the mask.
            RandHistogramShiftD(("image",), prob=1, num_control_points=30, allow_missing_keys=True),
        ])

    def forward(self, x):
        """Run the augmentation pipeline on sample dict ``x`` and return the result."""
        return self.random_rotated(x)
def save(before_x, after_x, new_path, new_name=""):
    """Write an augmented volume to disk using the geometry of a reference file.

    Args:
        before_x: Path of the reference image whose direction, origin and
            spacing are copied onto the output image.
        after_x: Augmented volume with two leading singleton axes; element
            ``[0, 0, ...]`` is written. May be a torch tensor or array-like.
        new_path: Destination file path passed to ``sitk.WriteImage``.
        new_name: Kept for backward compatibility. It only selected the pixel
            type used when reading the reference, which never affected the
            copied geometry or the written voxels.
    """
    volume = after_x[0, 0, ...]
    if isinstance(volume, torch.Tensor):
        # GetImageFromArray needs a host-memory array with no autograd graph.
        volume = volume.detach().cpu().numpy()

    # Only the reference header (direction/origin/spacing) is needed, so read
    # the image information without loading the voxel data.
    reader = sitk.ImageFileReader()
    reader.SetFileName(before_x)
    reader.ReadImageInformation()

    out_img = sitk.GetImageFromArray(volume)
    # NOTE(review): the augmentation rotates the in-plane axes; if the
    # reference spacing is anisotropic in-plane, the copied spacing (and
    # direction) no longer describe the rotated volume exactly — confirm.
    out_img.SetDirection(reader.GetDirection())
    out_img.SetOrigin(reader.GetOrigin())
    out_img.SetSpacing(reader.GetSpacing())
    sitk.WriteImage(out_img, new_path)
if __name__ == "__main__":
    image = r"D:\MyData\3Dircadb1_fusion_date\image_2.nii"
    label = r"D:\MyData\3Dircadb1_fusion_date\liver_2.nii"
    # NOTE(review): inputs are *_2 but outputs are image_0 / liver_1 —
    # confirm these output indices are intended.
    new_path = r"D:\MyData\3Dircadb1_fusion_date\image_0.nii"
    new_path1 = r"D:\MyData\3Dircadb1_fusion_date\liver_1.nii"

    # Load both volumes as numpy arrays (SimpleITK arrays are indexed z, y, x).
    ct1 = sitk.GetArrayFromImage(sitk.ReadImage(image))
    seg1 = sitk.GetArrayFromImage(sitk.ReadImage(label))

    # Add two leading singleton axes before handing the volumes to the pipeline.
    m = {
        "image": torch.from_numpy(ct1[None, None, ...]),
        "label": torch.from_numpy(seg1[None, None, ...]),
    }
    augs = aug()
    print(m["image"].shape)
    data_dict = augs.forward(m)

    save(image, data_dict["image"], new_path, "image")
    save(label, data_dict["label"], new_path1, "label")

    print(data_dict["image"].shape)

    # Visual check on slice 66; fall back to the middle slice for shallower volumes.
    depth = ct1.shape[0]
    slice_idx = 66 if depth > 66 else depth // 2
    # NOTE(review): the pipeline presumably flips the depth axis, so panel 2/3
    # show a different anatomical slice than panel 1 — confirm if comparing.
    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(ct1[slice_idx, ...])
    axes[1].imshow(data_dict["image"][0, 0, slice_idx, ...])
    axes[2].imshow(data_dict["label"][0, 0, slice_idx, ...])
    plt.show()
|