利用3D标签,生成RLE标签编码,并保存到csv文件

本文最后更新于:2023年4月7日 下午

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# coding:utf-8from glob import globimport osimport SimpleITK as sitkfrom pathlib import Pathimport numpy as npimport imageioimport pandas as pd
def rle_encode(mask, bg = 0) -> dict:
vec = mask.flatten()
nb = len(vec)
where = np.flatnonzero
starts = np.r_[0, where(~np.isclose(vec[1:], vec[:-1], equal_nan=True)) + 2]
lengths = np.diff(np.r_[starts, nb])
values = vec[starts]
assert len(starts) == len(lengths) == len(values)
rle = {}
for start, length, val in zip(starts, lengths, values):
if val == bg:
continue
rle[val] = rle.get(val, []) + [str(start), length]
# post-processing
rle = {lb: " ".join(map(str, id_lens)) for lb, id_lens in rle.items()}
return rle

def generate_rel(LABELS, path):
preds = []

for i in range(len(path)):
file = path[i]
file_name = file.split("\\")[-1].split("_seg")[0]
case = file_name.split("_")[0]
print("case:{}, file_name:{}".format(case, file_name))
seg = sitk.ReadImage(file)
seg = sitk.GetArrayFromImage(seg)
for j in range(seg.shape[0]):
if j>=0 and j<9:
number = str(0)+str(0)+str(0)+str(j+1)
elif j>=9 and j<99:
number = str(0)+str(0) + str(j+1)
else:
number = str(0) + str(j+1)
name = file_name+"_slice_"+number
output = seg[j, ...]
Snapshot_img = np.zeros(shape=(seg.shape[1],seg.shape[2],3), dtype=np.uint8) # png设置为3通道
Snapshot_img[:, :, 0][np.where(output == 1)] = 1 #我们也有3个标签,其中值分别为1,2,3,所以我们需要给每个标签都赋予不同的通道
Snapshot_img[:, :, 1][np.where(output == 2)] = 1
Snapshot_img[:, :, 2][np.where(output == 3)] = 1
rle_lb = rle_encode(Snapshot_img[:, :, 0]) if np.sum(Snapshot_img[:, :, 0]) > 1 else {}
rle_sb = rle_encode(Snapshot_img[:, :, 1]) if np.sum(Snapshot_img[:, :, 1]) > 1 else {}
rle_sto = rle_encode(Snapshot_img[:, :, 2]) if np.sum(Snapshot_img[:, :, 2]) > 1 else {}
index = (0,1,2)
rel = [rle_lb, rle_sb, rle_sto]
preds += [{"id": name, "class": lb, "predicted": rle.get(1, "")} for i, rle, lb in zip(index, rel, LABELS)]
df_pred = pd.DataFrame(preds)
df_pred.to_csv("submission.csv", index=False)

if __name__ == "__main__":
pred_file = glob(r"D:\compation\kaggle\3D_preprocess\a\*") # 获取到该文件夹下所有的标签(3D nii文件)
LABELS = ("large_bowel", "small_bowel", "stomach")
generate_rel(LABELS, pred_file)

结果:

 


打赏支持
“如果你觉得我的文章不错,不妨鼓励我继续写作。”

利用3D标签,生成RLE标签编码,并保存到csv文件
https://dreamoneyou.github.io/2022/利用3D标签,生成RLE标签编码,并保存到csv文件/
作者
九叶草
发布于
2022年5月6日
更新于
2023年4月7日
许可协议