Coverage for src/spectroflat/analyzer.py: 51%

292 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-03-28 07:59 +0000

1from __future__ import annotations 

2 

3import os 

4from datetime import datetime 

5 

6from .base import Logging 

7from .base.config import Config 

8from .inspection.flat_field_report import * 

9from .inspection.offset_map import plt_map 

10from .sensor.artificial_flat import ArtificialFlat 

11from .sensor.c2c_comp import extract_c2c_comp 

12from .sensor.flat import Flat 

13from .shift.img_rotation import RotationCorrection, RotationAnalysis 

14from .smile import OffsetMap, SmileInterpolator, SmileMapGenerator 

15from .smile.smoothing import * 

16from .utils.line_detection import find_line_cores 

17from .utils.processing import MP 

18 

#: Logger shared by everything in this module (obtained from the package ``Logging`` helper).
log = Logging.get_logger()

20 

21 

class Analyser:
    """
    Entry point of the smile detection library.

    The analyser derives, from a flat-field data cube, the offset map of the
    smile and a gain table that can be used to correct science frames.
    The cube is indexed by state first: ``cube[s]`` is one 2D frame per
    (modulation) state. The remaining two axes are presumably
    (spatial row, wavelength column); the report plots label axis 2 as lambda.

    If a report directory is given, a PDF report with detailed information on
    the generated offset map and gain table is written into that directory.

    Typical use::

        analyser = Analyser(cube, config, report_dir='out').run()
        omap, gain = analyser.offset_map, analyser.gain_table
    """

    def __init__(self, cube: np.ndarray, config: Config, report_dir: str | None = None):
        self._report = None
        self._out_dir = report_dir
        self._config = config
        # Working cube: rebound to progressively corrected versions (c2c pattern, pre-flat) during run().
        self._orig = cube
        # Untouched copy of the raw input; reference for rotation detection and the hard-flat iterations.
        self._input = cube.copy()
        self._rotation = 0
        #: The smile-corrected data cube
        self.desmiled = cube
        #: The created pre flat (Sensor Flat). Starts as a neutral all-ones flat
        #: (not uninitialised memory) so it is well defined if the sensor flat is skipped.
        self.pre_flat = np.ones(cube.shape)
        #: The column 2 column pattern (None unless it was extracted)
        self.c2c_pattern = None
        #: The created gain table
        self.gain_table = np.ones(cube.shape)
        #: The detected offset map of the smile
        self.offset_map = OffsetMap()

    def run(self) -> Analyser:
        """
        Perform the smile analysis on the given dataset.

        Steps: open the report (if a report dir was given), optionally remove the
        column-to-column pattern and apply a sensor flat, detect and correct the
        smile, and finally build the gain table.

        :return: ``self``; results are available as attributes
                 (``offset_map``, ``gain_table``, ``desmiled``, ...).
        """
        try:
            self._start()
            self._plt_input()
            self._remove_c2c_comp()
            self._apply_sensor_flat()
            self._analyse()
            self._create_gain_table()
            self._plt_applied_flat()
            return self
        finally:
            # Always close the PDF report, even if one of the steps raised.
            self._cleanup()

    def _apply_sensor_flat(self):
        # Generate one sensor (pre) flat per state and apply it, unless disabled via config.
        if not self._config.apply_sensor_flat:
            log.info('Sensor flat skipped....')
            return

        log.info('Generating and applying pre-flat.')
        flats = self._gen_pre_flat()
        self._apply_pre_flat(flats)
        self._report_pre_flat()

    def _remove_c2c_comp(self):
        # Extract the column-to-column response pattern (inside the ROI, if configured)
        # and multiply it into the working cube.
        if not self._config.sensor_flat.average_column_response_map:
            return

        log.info('Removing column-2-column response pattern...')
        full_shape = (self._orig.shape[1], self._orig.shape[2])
        self.c2c_pattern = extract_c2c_comp(self._crop_states(self._orig), full_shape)
        self._orig = self._orig * self.c2c_pattern

    def _crop_states(self, cube: np.ndarray) -> np.ndarray:
        # Crop every state frame of ``cube`` to the configured ROI; returns ``cube`` itself without ROI.
        roi = self._config.roi
        if roi is None:
            return cube
        return np.array([frame[roi] for frame in cube])

    def _gen_pre_flat(self) -> list:
        # One Flat per state, built from the working cube with the sensor flat settings.
        return [Flat.from_frame(frame, self._config.sensor_flat) for frame in self._orig]

    def _apply_pre_flat(self, flats: list):
        # Correct each state with its own flat; the corrected cube becomes both the
        # working cube and the current de-smiled estimate.
        states = self._orig.shape[0]
        self._orig = np.array([flats[i].correct(self._orig[i]) for i in range(states)])
        self.desmiled = self._orig
        self.pre_flat = np.array([f.flat for f in flats])

    def _report_pre_flat(self):
        if not self._report:
            return
        self._plt_3sigma(self.pre_flat, 'Pre Flats (3-sigma)')

    def _analyse(self):
        # Derotate, detect the smile offsets and apply the smile correction.
        self._derotate()
        self._plt_lines()
        self._compute_offsets()
        self._desmile()
        self._plt_image_results()

    def _derotate(self):
        # ``smile.rotation_correction`` may be None (no rotation), 'h'/'v' variants
        # (detect the angle from the raw input per state and average it) or a number.
        # The applied angle [deg] is written back into the config.
        correction = self._config.smile.rotation_correction
        if correction is None:
            self._config.smile.rotation_correction = self._rotation
            return

        states = self._input.shape[0]
        if correction in ['h', 'horizontal', 'horizontally']:
            log.info('Detecting rotation according to horizontal lines')
            rot = [RotationAnalysis.detect_horizontal_rotation(self._input[s]) for s in range(states)]
            self._rotation = np.mean(rot)
        elif correction in ['v', 'vertical', 'vertically']:
            log.info('Detecting rotation according to vertical lines')
            rot = [RotationAnalysis.detect_vertical_rotation(self._input[s]) for s in range(states)]
            self._rotation = np.mean(rot)
        else:
            self._rotation = float(correction)

        log.info('Using rotation correction with %.4f [deg]', self._rotation)
        img = [RotationCorrection(self.desmiled[s], self._rotation).bicubic() for s in range(states)]
        self._config.smile.rotation_correction = self._rotation
        self.desmiled = np.array(img)

    def _plt_lines(self):
        # Plot the lines detected in a 6-row average around the centre row of state 0
        # (centre of the ROI if one is configured).
        if self._report is None:
            return

        if self._config.roi is None:
            center = self._orig.shape[1] // 2
            row = np.average(self._orig[0, center - 3: center + 3], axis=0)
        else:
            roi = self._orig[0][self._config.roi]
            center = roi.shape[0] // 2
            row = np.average(roi[center - 3: center + 3], axis=0)

        _, lines = find_line_cores(row, self._config.smile)
        lines = [v for v in lines if v is not None]
        plt_selected_lines(row, lines, self._report)

    def _compute_offsets(self) -> None:
        # Generate the offset map from the (derotated) cube; optionally force all states to share offsets.
        smg = SmileMapGenerator(self._config.smile, self.desmiled).run()
        self.offset_map = smg.omap
        if not self._config.smile.state_aware:
            log.info('Enforcing same offset correction on all mod states')
            self.offset_map.enforce_same_offsets_on_all_states()
        self._append_offset_plots_to_report(smg)

    def _append_offset_plots_to_report(self, smg: SmileMapGenerator):
        if not self._report:
            return

        # Plot the map at 10%, 25%, 50%, 75% and 90% of the (ROI) height, in absolute row numbers.
        if self._config.roi is None:
            offset = 0
            total_rows = smg.omap.map.shape[1]
        else:
            offset = self._config.roi[0].start
            total_rows = self._config.roi[0].stop - self._config.roi[0].start
        rows = [total_rows // 10,
                total_rows // 4,
                total_rows // 2,
                total_rows - total_rows // 4,
                total_rows - total_rows // 10]
        rows = tuple(r + offset for r in rows)
        plt_map(self.offset_map, pdf=self._report, rows=rows, state_aware=self._config.smile.state_aware)

    def _desmile(self):
        # Apply the offset map to every state of the working cube (in parallel via MP).
        log.info('Applying smile correction...')
        args = [(self.offset_map, self._orig[s], s) for s in range(self._orig.shape[0])]
        result = dict(MP.simultaneous(SmileInterpolator.desmile_state, args))
        self.desmiled = np.array([result[s] for s in range(self._orig.shape[0])])

    def _start(self):
        # Open a time-stamped PDF report in the report dir (created if missing).
        if self._out_dir:
            os.makedirs(self._out_dir, exist_ok=True)
            fname = f'offset_analysis_report_{datetime.now().strftime("%y%m%d_%H%M%S")}.pdf'
            self._report = PdfPages(os.path.join(self._out_dir, fname))

    def _create_gain_table(self):
        log.info('Creating gain table...')
        self._artificial_flat()
        self._remove_lines()
        self._interpolate_line_residuals()

    def _refine_hard_flat(self) -> float:
        # Build an artificial flat from the current de-smiled cube, re-smile it with the
        # offset map and remove it from the raw input (times the c2c pattern, if any).
        # The result, normalised to mean 1 (within the ROI), becomes the new ``pre_flat``.
        # Returns the STD of the difference between the new and the previous (normalised) flat.
        input_img = self._input * self.c2c_pattern if self.c2c_pattern is not None else self._input
        if self._config.roi is None:
            af = ArtificialFlat(self.desmiled).create().resmile(self.offset_map)
            hard_flat = af.remove(input_img)
            hard_flat = hard_flat / hard_flat.mean()
        else:
            roi = (slice(None, None), self._config.roi[0], self._config.roi[1])
            af = ArtificialFlat(self.desmiled, roi=roi).create().resmile(self.offset_map).pad(self._input.shape)
            hard_flat = af.remove(input_img)
            # Outside the ROI the flat is neutral (1); inside, each state is normalised separately.
            temp = np.ones(input_img.shape)
            for state in range(temp.shape[0]):
                hf = hard_flat[state][self._config.roi]
                temp[state][self._config.roi] = hf / hf.mean()
            hard_flat = temp
        std = np.std(hard_flat / hard_flat.mean() - self.pre_flat / self.pre_flat.mean())
        log.info('\tSTD of normalized diff of consecutive flats: %e', std)
        self.pre_flat = hard_flat
        return float(std)

    def _correct_original(self) -> None:
        # Flat-field the raw input with the current pre flat (c2c pattern divided out) and de-smile it.
        flat = self.pre_flat / self.c2c_pattern if self.c2c_pattern is not None else self.pre_flat
        flat = Flat(flat)
        current = flat.correct(self._input)
        args = [(self.offset_map, current[s], s) for s in range(current.shape[0])]
        result = dict(MP.simultaneous(SmileInterpolator.desmile_state, args, workers=4))
        self.desmiled = np.array([result[s] for s in range(current.shape[0])])

    def _refine_results(self) -> float:
        # Refines the hard flat and applies the correction to create a new
        # de-smiled image.
        # Returns the standard deviation of the difference of two consecutive hard flats.
        std = self._refine_hard_flat()
        self._correct_original()
        return std

    def _artificial_flat(self):
        # Iterate the hard flat ``sensor_flat_iterations`` times, then fold the c2c pattern into it.
        log.info('iterating hard flat')
        first_order = self.desmiled
        stds = [self._refine_results() for _ in range(self._config.sensor_flat_iterations)]
        self._apply_c2c_pattern()
        self._plot_refined_results(stds, first_order)

    def _apply_c2c_pattern(self):
        # Divide the extracted c2c pattern into the pre flat, then re-extract the residual
        # pattern from the flat-fielded input and divide that out as well.
        if self.c2c_pattern is None:
            return

        self.pre_flat = self.pre_flat / self.c2c_pattern
        img = self._crop_states(Flat(self.pre_flat).correct(self._input))
        c2c_pattern = extract_c2c_comp(img, (self._orig.shape[1], self._orig.shape[2]))
        self.pre_flat = self.pre_flat / c2c_pattern

    def _plot_refined_results(self, stds: list, first_order: np.ndarray):
        if not self._report:
            return

        plt_std_of_consecutive_hard_flats(stds, self._report)
        self._plt_3sigma(self.pre_flat, 'New iterated hard flat (3-sigma)')
        plt_state_imgs(self.desmiled, title='Desmiled corrected input', pdf=self._report)
        cut = self._state0_center_cut(first_order.shape[1])
        self._plt_cuts((first_order[cut], "Old"), (self.desmiled[cut], "New"))
        plt_spatial_comparison(self._input, self.desmiled, pdf=self._report, roi=self._config.roi)

    def _state0_center_cut(self, rows: int) -> tuple:
        # Index of the centre row of state 0 (absolute row number, restricted to the ROI columns
        # if an ROI is configured); ``rows`` is the number of rows of the full frame.
        roi = self._config.roi
        if roi is None:
            return 0, rows // 2
        return 0, roi[0].start + (roi[0].stop - roi[0].start) // 2, roi[1]

    def _plt_cuts(self, a: tuple, b: tuple):
        # ``a``/``b`` are (spectrum, label) pairs; top: full cut, bottom: zoom into the central quarter.
        fig, ax = plt.subplots(nrows=2, ncols=1, figsize=A4_LANDSCAPE)
        fig.suptitle('State 0 spectra before and after correction')
        ax[0].plot(a[0], label=a[1])
        ax[0].plot(b[0], label=b[1])
        n = len(a[0]) // 2
        ax[1].plot(range(n - n // 4, n + n // 4), a[0][n - n // 4:n + n // 4])
        ax[1].plot(range(n - n // 4, n + n // 4), b[0][n - n // 4:n + n // 4])
        fig.legend()
        fig.tight_layout()
        self._report.savefig(fig)
        plt.close(fig)

    def _remove_lines(self):
        log.info('Removing vertical lines')
        if self._config.roi is None:
            self._remove_lines_full()
        else:
            self._remove_lines_roi()
        self._plt_line_removal()

    def _remove_lines_roi(self):
        # Gain table = line-removed de-smiled cube inside the ROI (each state normalised to mean 1),
        # neutral (1) outside of it.
        roi = (slice(None, None), self._config.roi[0], self._config.roi[1])
        lr = LineRemover(self.desmiled[roi]).run()
        self.gain_table = np.ones(self._input.shape)
        # ``roi`` already spans all states, so a single assignment suffices.
        self.gain_table[roi] = np.array([lr.result[s] / lr.result[s].mean() for s in range(lr.result.shape[0])])

    def _remove_lines_full(self):
        # Gain table = line-removed de-smiled cube, each state normalised to mean 1.
        lr = LineRemover(self.desmiled).run()
        self.gain_table = np.array([lr.result[s] / lr.result[s].mean() for s in range(lr.result.shape[0])])

    def _plt_line_removal(self):
        if not self._report:
            return
        self._plt_3sigma(self.gain_table, 'Gain table after removing vertical lines (3-Sigma)')

    def _interpolate_line_residuals(self):
        # Optionally smooth the gain table (residual removal + Gaussian blur), see ``smile.smooth``.
        if not self._config.smile.smooth:
            return

        log.info('Smoothing residuals')
        if self._config.roi is None:
            self._interpolate_full()
        else:
            self._interpolate_roi()
        self._plt_smoothing_result()

    def _interpolate_full(self):
        for s in range(self.gain_table.shape[0]):
            self.gain_table[s] = ResidualsRemover(self.gain_table[s]).run().img
            self.gain_table[s] = GaussianBlur(self.gain_table[s]).run().gain

    def _interpolate_roi(self):
        roi = self._config.roi
        for s in range(self.gain_table.shape[0]):
            self.gain_table[s][roi] = ResidualsRemover(self.gain_table[s][roi]).run().img
            self.gain_table[s][roi] = GaussianBlur(self.gain_table[s][roi]).run().gain

    def _cleanup(self):
        # Close the PDF report (if open) and drop the reference so nothing writes to a closed file.
        if self._report:
            self._report.close()
            self._report = None

    def _plt_image_results(self):
        if self._report is None:
            return
        plt_adjustment_comparison(self._orig[0], self.desmiled[0], self._report, roi=self._config.roi)
        inp = np.average(self._input, axis=0)
        plt_img(inp, title='Averaged input image', pdf=self._report, roi=self._config.roi)
        plt_img(np.average(self.desmiled, axis=0), title='Averaged input image after de-smiling',
                pdf=self._report, clim=[inp.min(), inp.max()])

    def _plt_smoothing_result(self):
        if not self._report:
            return
        self._plt_3sigma(self.gain_table, 'Gain table after smoothing (3-Sigma)')

    def _plt_3sigma(self, cube: np.ndarray, title: str) -> None:
        # Plot all state images of ``cube`` with the colour range clipped to mean +/- 3 sigma.
        m, s = cube.mean(), 3 * cube.std()
        plt_state_imgs(cube, title=title, pdf=self._report, clim=[m - s, m + s])

    def _plt_applied_flat(self):
        if not self._report:
            return

        # Divide by the gain table where it is non-zero; elsewhere the de-smiled value is kept.
        corrected = np.true_divide(self.desmiled.astype('float32'), self.gain_table.astype('float32'),
                                   out=self.desmiled.astype('float32'), where=self.gain_table != 0,
                                   dtype='float64')
        plt_state_imgs(corrected, title='Flat fielded input image', pdf=self._report)
        plt_spatial_comparison(self._input, corrected, pdf=self._report, roi=self._config.roi)
        # ``corrected`` is reused as scratch buffer: normalised state i minus normalised state 0.
        corrected[0] = self.desmiled[0] / np.mean(self.desmiled[0])
        for i in range(1, corrected.shape[0]):
            corrected[i] = (self.desmiled[i] / self.desmiled[i].mean()) - corrected[0]
        corrected[0] = np.zeros(corrected[0].shape)
        plt_state_imgs(corrected,
                       title='Mod state 0 subtracted from the other de-smiled images (Normalized)',
                       pdf=self._report)
        self._plt_diff_cuts(corrected)

    def _plt_diff_cuts(self, img: np.ndarray):
        # Plot the centre-row cut of every delta image (states 1..n minus state 0).
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=A4_LANDSCAPE)
        for s in range(1, img.shape[0]):
            ax.plot(img[s, img.shape[1] // 2], label=f'<#{s}> - <#0>')
        ax.grid(True)
        fig.suptitle("Spectral cuts of delta images")
        ax.legend()
        ax.set_ylim([-0.015, 0.015])
        ax.set_xlim([0, img.shape[2]])
        ax.set_xlabel(r'$\lambda$ [px]')
        fig.tight_layout()
        self._report.savefig(fig)
        plt.close(fig)

    def _plt_input(self):
        if not self._report:
            return

        plt_state_imgs(self._orig, title='Input data', pdf=self._report)