from __future__ import print_function, division
import sys
sys.path.insert(0, 'lib')

import glob
import json
import os
import shutil
import time
import zipfile

import numpy as np
import pydicom
from skimage.transform import resize

from calculate_t2 import fit_t2
# The wildcard imports below supply the helpers used throughout this script
# (get_model, assemble_4d_mese_v2, whiten_img, t2_threshold, optimal_binarize,
# projection, get_physical_dimensions, ...).
from lesion_utils import *
from make_inference_csv import *
from segmentation_refinement import *
from projection import *
from inference_utils import *

print("STARTING NOW")
input_dir = '/workspace/input'
vol_zip_list = np.sort([os.path.join(input_dir, i) for i in os.listdir(input_dir) if i[-4:] == '.zip'])

# Get model
model_weight_file = '/workspace/model_weights/model_weights_quartileNormalization_echoAug.h5'
model = get_model(model_weight_file)
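# get_model presumably builds the segmentation network and loads the trained
# HDF5 weights; a minimal sketch of that pattern, assuming a Keras model (the
# actual builder lives in the imported helper modules; build_network is a
# hypothetical name):
#   model = build_network(input_shape=(384, 384, 1))  # hypothetical builder
#   model.load_weights(model_weight_file)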
# Prepare the CSV where we will write a summary of all the results
results_summary_path = '/workspace/output/results_summary.csv'
region_list = ['all', 'superficial', 'deep', 'L', 'M',
               'LA', 'LC', 'LP', 'MA', 'MC', 'MP',
               'SL', 'DL', 'SM', 'DM',
               'SLA', 'SLC', 'SLP', 'SMA', 'SMC', 'SMP',
               'DLA', 'DLC', 'DLP', 'DMA', 'DMC', 'DMP']
if os.path.exists(results_summary_path):
    ## If a summary file already exists, don't re-analyze images that have already been analyzed
    previous_summary = np.genfromtxt(results_summary_path, delimiter=',', invalid_raise=False, dtype='str')
    previously_analyzed_images = previous_summary[1:, 0]
    zip_base_names = np.array([os.path.basename(i) for i in vol_zip_list])
    vol_zip_list = vol_zip_list[~np.isin(zip_base_names, previously_analyzed_images)]
    ## and keep appending to the existing summary file
    output_file = open(results_summary_path, 'a+')
else:
    ## Otherwise, start a new one
    output_file = open(results_summary_path, 'w')
    output_file.write('filename,' + ','.join(region_list) + '\n')
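# Each processed image then contributes one row to the summary, e.g.
# (filename and values illustrative):
#   filename,all,superficial,deep,L,M,...,DMA,DMC,DMP
#   some_scan.zip,42,45,38,...,40,39,37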
# Define functions that will help us find the DICOM files in the input zip directories
def get_slices(scan_dir):
    '''
    scan_dir: path to a folder of DICOMs in which each DICOM file represents
    one slice of the image volume.
    '''
    list_of_slices = glob.glob("{}/**".format(scan_dir), recursive=True)
    return list(filter(is_dicom, list_of_slices))

def is_dicom(fullpath):
    '''Heuristic check: treat .dcm files and extensionless files as DICOMs.'''
    if os.path.isdir(fullpath):
        return False
    _, path = os.path.split(fullpath)
    path = path.lower()
    if path[-4:] == ".dcm":
        return True
    if not "." in path:
        return True
    return False
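# Note: a stricter test would read the 128-byte file preamble and check for the
# "DICM" magic number, e.g. pydicom.misc.is_dicom(fullpath); the filename
# heuristic above can misclassify extensionless non-DICOM files.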
# Process input zip directories
total_time = 0
for zip_num, vol_zip in enumerate(vol_zip_list):
    print("Processing file %s..." % os.path.basename(vol_zip))
    time1 = time.time()
    new_dir_name = os.path.join('/workspace/output', os.path.splitext(os.path.basename(vol_zip))[0])
    os.makedirs(new_dir_name, exist_ok=True)
    dicom_sub_dir = os.path.join(new_dir_name, "dicom")
    raw_extract_dir = os.path.join(new_dir_name, "raw_extract")
    os.makedirs(dicom_sub_dir, exist_ok=True)
    os.makedirs(raw_extract_dir, exist_ok=True)

    # Unzip the image volume. If the zip file contains an inner directory, move the files out of it.
    with zipfile.ZipFile(vol_zip, 'r') as zip_ref:
        zip_ref.extractall(raw_extract_dir)
    slice_path_list = get_slices(raw_extract_dir)
    for s in slice_path_list:
        shutil.copy(s, os.path.join(dicom_sub_dir, os.path.basename(s)))
    shutil.rmtree(raw_extract_dir)

    # Create a MESE (multi-echo spin-echo) numpy array of shape (slices, echoes, rows, cols)
    mese, times = assemble_4d_mese_v2(dicom_sub_dir)

    # If the slices are not 384x384, resize them
    # (the model was trained on 384x384 images from the OAI, the Osteoarthritis Initiative)
    original_shape = None
    if (mese.shape[-2] != 384) or (mese.shape[-1] != 384):
        original_shape = mese.shape
        mese_resized = np.zeros((mese.shape[0], mese.shape[1], 384, 384))
        for s in range(mese.shape[0]):
            for echo in range(mese.shape[1]):
                mese_resized[s, echo, :, :] = resize(mese[s, echo, :, :], (384, 384), anti_aliasing=True)
        mese = mese_resized
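    # Caveat: skimage's resize() runs integer inputs through img_as_float, which
    # rescales them to [0, 1]; if assemble_4d_mese_v2 returns integer DICOM pixel
    # data, passing preserve_range=True would keep the original intensity scale:
    #   resize(mese[s, echo, :, :], (384, 384), anti_aliasing=True, preserve_range=True)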
    # Whiten (i.e. normalize) the echo of each slice whose echo time is closest to 20 ms
    mese_white = []
    skip_flag = 0
    for i, s in enumerate(mese):
        if (np.sum(times[i] == None) == 0) and (len(times[i]) > 3):
            slice_times = times[i]
            slice_20ms_idx = np.argmin(np.abs(slice_times - .02))
            mese_white.append(whiten_img(s[slice_20ms_idx, :, :], normalization='quartile'))
        else:
            skip_flag = skip_flag + 1
    if skip_flag > 0:
        print("Missing Echo Times. Skipping this image.")
        output_file.write('%s,' % os.path.basename(vol_zip))
        output_file.write('Missing Echo Times')
        output_file.write('\n')
        continue
    mese_white = np.stack(mese_white).squeeze()
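    # whiten_img with normalization='quartile' presumably centers and scales each
    # slice using robust quartile statistics rather than mean/std; a minimal
    # sketch of that idea (the real implementation lives in the helper modules):
    #   q1, q2, q3 = np.percentile(img, [25, 50, 75])
    #   img_white = (img - q2) / (q3 - q1)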
    # Estimate segmentation
    seg_pred = model.predict(mese_white.reshape(-1, 384, 384, 1), batch_size=6)
    seg_pred = seg_pred.squeeze()  # SAVE THIS

    # Calculate T2 map
    t2 = fit_t2(mese, times, segmentation=seg_pred, n_jobs=4, show_bad_pixels=False)
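    # fit_t2 presumably fits the mono-exponential decay S(TE) = S0 * exp(-TE / T2)
    # at each segmented voxel; a minimal per-voxel sketch of that model (names
    # here are illustrative, not the helper's actual API):
    #   from scipy.optimize import curve_fit
    #   def decay(te, s0, t2): return s0 * np.exp(-te / t2)
    #   (s0_hat, t2_hat), _ = curve_fit(decay, echo_times, voxel_signal, p0=(voxel_signal[0], 0.03))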
    # Refine the segmentation by throwing out non-physiologic T2 values
    seg_pred_refined, t2_refined = t2_threshold(seg_pred, t2, t2_low=0, t2_high=100)
    seg_pred_refined, t2_refined = optimal_binarize(seg_pred_refined, t2_refined, prob_threshold=0.501, voxel_count_threshold=425)
    # Check to make sure that the model found some cartilage
    if np.sum(seg_pred_refined) < 1000:
        print("Model did not find cartilage. Skipping this image.")
        output_file.write('%s,' % os.path.basename(vol_zip))
        output_file.write('Model did not find cartilage')
        output_file.write('\n')
        continue
    # Project the 3D T2 map onto a 2D surface using polar coordinates
    angular_bin = 5
    visualization, thickness_map, min_rho_map, max_rho_map, avg_vals_dict, R = projection(t2_refined,
                                                                                          thickness_div=0.5,
                                                                                          values_threshold=100,
                                                                                          angular_bin=angular_bin,
                                                                                          region_stat='mean',
                                                                                          fig=False)
    row_distance, column_distance = get_physical_dimensions(img_dir=dicom_sub_dir,
                                                            t2_projection=visualization,
                                                            projection_pixel_radius=R,
                                                            angular_bin=angular_bin)
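    # With angular_bin = 5 degrees, the projection presumably collapses the
    # cartilage shell to one row per slice and one column per 5-degree angular
    # sector, and avg_vals_dict holds the mean T2 of each region code in
    # region_list (e.g. 'SLA' would be superficial lateral anterior).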
    # Resize the outputs back to the size of the original input image
    if original_shape is not None:
        seg_pred_resized = np.zeros((original_shape[0], original_shape[2], original_shape[3]))
        t2_resized = np.zeros((original_shape[0], original_shape[2], original_shape[3]))
        seg_pred_refined_resized = np.zeros((original_shape[0], original_shape[2], original_shape[3]))
        t2_refined_resized = np.zeros((original_shape[0], original_shape[2], original_shape[3]))
        for s in range(seg_pred.shape[0]):
            seg_pred_resized[s, :, :] = resize(seg_pred[s, :, :], (original_shape[2], original_shape[3]), anti_aliasing=True, preserve_range=True)
            t2_resized[s, :, :] = resize(t2[s, :, :], (original_shape[2], original_shape[3]), anti_aliasing=True, preserve_range=True)
            seg_pred_refined_resized[s, :, :] = resize(seg_pred_refined[s, :, :], (original_shape[2], original_shape[3]), anti_aliasing=True, preserve_range=True)
            t2_refined_resized[s, :, :] = resize(t2_refined[s, :, :], (original_shape[2], original_shape[3]), anti_aliasing=True, preserve_range=True)
        # Re-binarize the interpolated masks
        seg_pred = 1 * (seg_pred_resized > .501)
        seg_pred_refined = 1 * (seg_pred_refined_resized > .501)
        t2 = t2_resized
        t2_refined = t2_refined_resized
    # Save the T2 image, segmentation, and projection results
    ## Save the 3D binary segmentation mask as a numpy array
    seg_path = os.path.join(new_dir_name, "segmentation_mask.npy")
    np.save(seg_path, seg_pred)
    refined_seg_path = os.path.join(new_dir_name, "segmentation_mask_refined.npy")
    np.save(refined_seg_path, seg_pred_refined)
    ## Save the 3D binary segmentation mask as a folder of CSV files (one per slice)
    seg_sub_dir = os.path.join(new_dir_name, "segmentation_mask_csv")
    os.makedirs(seg_sub_dir, exist_ok=True)
    for i, s in enumerate(seg_pred):
        slice_path = os.path.join(seg_sub_dir, str(i).zfill(3) + ".csv")
        np.savetxt(slice_path, s, delimiter=",", fmt='%10.5f')
    seg_sub_dir = os.path.join(new_dir_name, "segmentation_mask_csv_refined")
    os.makedirs(seg_sub_dir, exist_ok=True)
    for i, s in enumerate(seg_pred_refined):
        slice_path = os.path.join(seg_sub_dir, str(i).zfill(3) + ".csv")
        np.savetxt(slice_path, s, delimiter=",", fmt='%10.5f')
    ## Save the 3D T2 image as a numpy array
    t2_img_path = os.path.join(new_dir_name, "t2.npy")
    np.save(t2_img_path, t2)
    t2_img_path_refined = os.path.join(new_dir_name, "t2_refined.npy")
    np.save(t2_img_path_refined, t2_refined)
    ## Save the 3D T2 image as a folder of CSV files (one per slice)
    t2_sub_dir = os.path.join(new_dir_name, "t2_csv")
    os.makedirs(t2_sub_dir, exist_ok=True)
    for i, s in enumerate(t2):
        slice_path = os.path.join(t2_sub_dir, str(i).zfill(3) + ".csv")
        np.savetxt(slice_path, s, delimiter=",", fmt='%10.5f')
    t2_sub_dir = os.path.join(new_dir_name, "t2_csv_refined")
    os.makedirs(t2_sub_dir, exist_ok=True)
    for i, s in enumerate(t2_refined):
        slice_path = os.path.join(t2_sub_dir, str(i).zfill(3) + ".csv")
        np.savetxt(slice_path, s, delimiter=",", fmt='%10.5f')
    ## Save the 2D projection of the T2 map as a numpy array
    t2_projection_path = os.path.join(new_dir_name, "t2_projection.npy")
    np.save(t2_projection_path, visualization)
    ## Save the 2D projection of the T2 map as a CSV
    t2_projection_csv_path = os.path.join(new_dir_name, "t2_projection.csv")
    np.savetxt(t2_projection_csv_path, visualization, delimiter=",", fmt='%10.5f')
    ## Save the 2D projection thickness map as a numpy array
    thickness_projection_path = os.path.join(new_dir_name, "thickness_projection.npy")
    np.save(thickness_projection_path, thickness_map)
    ## Save the 2D projection thickness map as a CSV
    thickness_projection_csv_path = os.path.join(new_dir_name, "thickness_projection.csv")
    np.savetxt(thickness_projection_csv_path, thickness_map, delimiter=",", fmt='%10.5f')
    ## Save the physical dimensions of the 2D projections as a JSON
    projection_dimensions_dict = {}
    projection_dimensions_dict['row_distance(mm)'] = row_distance
    projection_dimensions_dict['column_distance(mm)'] = column_distance
    projection_dimensions_dict_path = os.path.join(new_dir_name, "projection_dimensions.json")
    with open(projection_dimensions_dict_path, 'w') as fp:
        json.dump(projection_dimensions_dict, fp)
    ## Save the region-average T2 dictionary as a JSON
    t2_region_json_path = os.path.join(new_dir_name, "region_mean_t2.json")
    with open(t2_region_json_path, 'w') as fp:
        json.dump(avg_vals_dict, fp)
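    # region_mean_t2.json maps each region code in region_list to its mean T2,
    # e.g. (values illustrative):
    #   {"all": 41.8, "superficial": 44.2, ..., "DMP": 38.5}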
    # Record the average regional T2 values for this image in the summary CSV
    output_file.write('%s,' % os.path.basename(vol_zip))
    row_vals = []
    for r in region_list:
        val = avg_vals_dict[r]
        # NaN regions are written as 'nan'; finite values are truncated to integers by %d
        row_vals.append(str(val) if np.isnan(val) else '%d' % val)
    output_file.write(','.join(row_vals))
    output_file.write('\n')
    time2 = time.time()
    total_time = total_time + (time2 - time1)
    avg_pace = total_time / (zip_num + 1)
    files_remaining = len(vol_zip_list) - (zip_num + 1)
    print("Estimated time remaining for all images (minutes):", np.round(files_remaining * avg_pace / 60, decimals=1))
output_file.close()

print()
print("Processing finished. Find results in the '/workspace/output' folder.")