Coverage for src/spectroflat/analyzer.py: 51%
292 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
1from __future__ import annotations
3import os
4from datetime import datetime
6from .base import Logging
7from .base.config import Config
8from .inspection.flat_field_report import *
9from .inspection.offset_map import plt_map
10from .sensor.artificial_flat import ArtificialFlat
11from .sensor.c2c_comp import extract_c2c_comp
12from .sensor.flat import Flat
13from .shift.img_rotation import RotationCorrection, RotationAnalysis
14from .smile import OffsetMap, SmileInterpolator, SmileMapGenerator
15from .smile.smoothing import *
16from .utils.line_detection import find_line_cores
17from .utils.processing import MP
# Module-level logger shared by every Analyser processing step.
log = Logging.get_logger()
class Analyser:
    """
    The Analyser is the smile detection library entry point.

    This class generates the offset map and gain table that can be used to correct science frames.
    If a report dir is given, a PDF report with detailed information on the generated offset map and
    gain table is created.

    The input ``cube`` is indexed as ``cube[state, row, column]`` throughout: every step iterates
    over ``shape[0]`` and treats ``shape[1]``/``shape[2]`` as the frame size. The report plots the
    column axis as the spectral (lambda) axis.

    Usage::

        analyser = Analyser(cube, config, report_dir='reports').run()
        offsets, gains = analyser.offset_map, analyser.gain_table
    """

    def __init__(self, cube: np.ndarray, config: Config, report_dir: str | None = None):
        self._report = None
        self._out_dir = report_dir
        self._config = config
        self._orig = cube
        # Untouched copy of the input. It is used for rotation detection and as the base image
        # when refined flats are re-applied.
        self._input = cube.copy()
        self._rotation = 0
        #: The smile-corrected data cube
        self.desmiled = cube
        #: The created pre flat (Sensor Flat). It starts as a neutral all-ones flat so that later
        #: steps never read uninitialised memory when the sensor flat step is skipped.
        self.pre_flat = np.ones(cube.shape)
        #: The column 2 column pattern (None unless the c2c removal step ran)
        self.c2c_pattern = None
        #: The created gain table (neutral all-ones until :meth:`run` computes it)
        self.gain_table = np.ones(cube.shape)
        #: The detected offset map of the smile
        self.offset_map = OffsetMap()

    def run(self) -> Analyser:
        """
        Perform the smile analysis on the given dataset.

        The steps run in this order:

        1. Optional column-2-column removal.
        2. Optional sensor (pre) flat.
        3. Rotation correction.
        4. Offset-map detection and de-smiling.
        5. Gain-table creation.

        The PDF report (if any) is always closed, even when a step raises.

        Returns self, so results can be read straight from the call.
        """
        try:
            self._start()
            self._plt_input()
            self._remove_c2c_comp()
            self._apply_sensor_flat()
            self._analyse()
            self._create_gain_table()
            self._plt_applied_flat()
            return self
        finally:
            self._cleanup()

    def _apply_sensor_flat(self):
        """Generate one pre-flat per state and apply it, unless ``config.apply_sensor_flat`` is off."""
        if not self._config.apply_sensor_flat:
            log.info('Sensor flat skipped....')
            return

        log.info('Generating and applying pre-flat.')
        flats = self._gen_pre_flat()
        self._apply_pre_flat(flats)
        self._report_pre_flat()

    def _remove_c2c_comp(self):
        """
        Estimate the column-to-column response pattern and multiply it into ``self._orig``.

        This step is skipped unless ``config.sensor_flat.average_column_response_map`` is set.
        With an ROI, only the ROI of every state is used for the estimate. In both cases the full
        frame shape is passed to ``extract_c2c_comp``, presumably as the output shape of the pattern.
        """
        if not self._config.sensor_flat.average_column_response_map:
            return

        log.info('Removing column-2-column response pattern...')
        roi = self._config.roi
        if roi is None:
            self.c2c_pattern = extract_c2c_comp(self._orig, (self._orig.shape[1], self._orig.shape[2]))
        else:
            img = np.array([self._orig[s][roi] for s in range(self._orig.shape[0])])
            self.c2c_pattern = extract_c2c_comp(img, (self._orig.shape[1], self._orig.shape[2]))
        self._orig = self._orig * self.c2c_pattern

    def _gen_pre_flat(self) -> list:
        """Return one ``Flat`` per state, built from ``self._orig`` with ``config.sensor_flat``."""
        states = self._orig.shape[0]
        return [Flat.from_frame(self._orig[i], self._config.sensor_flat) for i in range(states)]

    def _apply_pre_flat(self, flats: list):
        """
        Correct every state with its flat and store the stacked flat arrays in ``pre_flat``.

        The corrected cube becomes both ``self._orig`` and the current ``desmiled`` cube.
        """
        states = self._orig.shape[0]
        self._orig = np.array([flats[i].correct(self._orig[i]) for i in range(states)])
        self.desmiled = self._orig
        self.pre_flat = np.array([f.flat for f in flats])

    def _report_pre_flat(self):
        """Plot the pre flats, clipped to mean +/- 3 sigma, into the report (if any)."""
        if self._report is None:
            return

        m, s = self.pre_flat.mean(), 3 * self.pre_flat.std()
        plt_state_imgs(self.pre_flat, title='Pre Flats (3-sigma)', pdf=self._report, clim=[m - s, m + s])

    def _analyse(self):
        """Derotate, detect the smile offsets and de-smile the data, plotting along the way."""
        self._derotate()
        self._plt_lines()
        self._compute_offsets()
        self._desmile()
        self._plt_image_results()

    def _derotate(self):
        """
        Determine and apply the rotation correction to ``self.desmiled``.

        ``config.smile.rotation_correction`` selects the mode:

        - ``None``: no correction. The current rotation (initially 0) is written back to the config.
        - ``'h'``/``'horizontal'``/``'horizontally'`` (or the vertical equivalents): the angle is
          detected per state on the raw input and averaged.
        - anything else: converted with ``float`` and used directly (the log states degrees).

        The applied angle is written back into the config.
        """
        if self._config.smile.rotation_correction is None:
            self._config.smile.rotation_correction = self._rotation
            return

        if self._config.smile.rotation_correction in ['h', 'horizontal', 'horizontally']:
            log.info('Detecting rotation according to horizontal lines')
            rot = [RotationAnalysis.detect_horizontal_rotation(self._input[s]) for s in range(self._input.shape[0])]
            self._rotation = np.mean(rot)
        elif self._config.smile.rotation_correction in ['v', 'vertical', 'vertically']:
            log.info('Detecting rotation according to vertical lines')
            rot = [RotationAnalysis.detect_vertical_rotation(self._input[s]) for s in range(self._input.shape[0])]
            self._rotation = np.mean(rot)
        else:
            self._rotation = float(self._config.smile.rotation_correction)

        log.info('Using rotation correction with %.4f [deg]', self._rotation)
        img = [RotationCorrection(self.desmiled[s], self._rotation).bicubic() for s in range(self._input.shape[0])]
        self._config.smile.rotation_correction = self._rotation
        self.desmiled = np.array(img)

    def _plt_lines(self):
        """Plot the lines detected in an averaged central row of state 0 into the report (if any)."""
        if self._report is None:
            return

        if self._config.roi is None:
            center = self._orig.shape[1] // 2
            row = np.average(self._orig[0, center - 3: center + 3], axis=0)
        else:
            roi = self._orig[0][self._config.roi]
            center = roi.shape[0] // 2
            row = np.average(roi[center - 3: center + 3], axis=0)

        _, lines = find_line_cores(row, self._config.smile)
        lines = [v for v in lines if v is not None]
        plt_selected_lines(row, lines, self._report)

    def _compute_offsets(self) -> None:
        """Detect the smile offset map on the current (derotated) data."""
        smg = SmileMapGenerator(self._config.smile, self.desmiled).run()
        self.offset_map = smg.omap
        if not self._config.smile.state_aware:
            log.info('Enforcing same offset correction on all mod states')
            self.offset_map.enforce_same_offsets_on_all_states()
        self._append_offset_plots_to_report(smg)

    def _append_offset_plots_to_report(self, smg: SmileMapGenerator):
        """Plot the offset map at five representative rows (10/25/50/75/90 % of the height)."""
        if self._report is None:
            return

        if self._config.roi is None:
            offset = 0
            total_rows = smg.omap.map.shape[1]
        else:
            offset = self._config.roi[0].start
            total_rows = self._config.roi[0].stop - self._config.roi[0].start
        rows = [total_rows // 10,
                total_rows // 4,
                total_rows // 2,
                total_rows - total_rows // 4,
                total_rows - total_rows // 10]
        # Translate ROI-relative rows into full-frame row indices.
        rows = tuple(r + offset for r in rows)
        plt_map(self.offset_map, pdf=self._report, rows=rows, state_aware=self._config.smile.state_aware)

    def _desmile(self):
        """De-smile every state of ``self._orig`` in parallel, using the current offset map."""
        log.info('Applying smile correction...')
        args = [(self.offset_map, self._orig[s], s) for s in range(self._orig.shape[0])]
        # Results are keyed by state index, because they may arrive out of order.
        result = dict(MP.simultaneous(SmileInterpolator.desmile_state, args))
        self.desmiled = np.array([result[s] for s in range(self._orig.shape[0])])

    def _start(self):
        """Create the report directory and open a timestamped PDF report, if a report dir was given."""
        if self._out_dir:
            os.makedirs(self._out_dir, exist_ok=True)
            fname = f'offset_analysis_report_{datetime.now().strftime("%y%m%d_%H%M%S")}.pdf'
            self._report = PdfPages(os.path.join(self._out_dir, fname))

    def _create_gain_table(self):
        """Refine the hard flat, remove spectral lines and optionally smooth the residual gain table."""
        log.info('Creating gain table...')
        self._artificial_flat()
        self._remove_lines()
        self._interpolate_line_residuals()

    def _refine_hard_flat(self) -> float:
        """
        Derive a new hard flat and store it in ``self.pre_flat``.

        The hard flat is the input with the re-smiled artificial flat removed, normalised to mean 1.
        With an ROI, normalisation happens inside the ROI only and the rest is set to 1.

        Returns the STD of the difference between the normalised new flat and the previous one.
        """
        input_img = self._input * self.c2c_pattern if self.c2c_pattern is not None else self._input
        if self._config.roi is None:
            af = ArtificialFlat(self.desmiled).create().resmile(self.offset_map)
            hard_flat = af.remove(input_img)
            hard_flat = hard_flat / hard_flat.mean()
        else:
            roi = (slice(None, None), self._config.roi[0], self._config.roi[1])
            af = ArtificialFlat(self.desmiled, roi=roi).create().resmile(self.offset_map).pad(self._input.shape)
            hard_flat = af.remove(input_img)
            temp = np.ones(input_img.shape)
            for state in range(temp.shape[0]):
                hf = hard_flat[state][self._config.roi]
                temp[state][self._config.roi] = hf / hf.mean()
            hard_flat = temp
        std = np.std(hard_flat / hard_flat.mean() - self.pre_flat / self.pre_flat.mean())
        log.info('\tSTD of normalized diff of consecutive flats: %e', std)
        self.pre_flat = hard_flat
        return float(std)

    def _correct_original(self) -> None:
        """Apply the current flat to the raw input and de-smile the result into ``self.desmiled``."""
        # The c2c pattern was multiplied into the data, so it is divided out of the flat here.
        flat = self.pre_flat / self.c2c_pattern if self.c2c_pattern is not None else self.pre_flat
        flat = Flat(flat)
        current = flat.correct(self._input)
        args = [(self.offset_map, current[s], s) for s in range(current.shape[0])]
        result = dict(MP.simultaneous(SmileInterpolator.desmile_state, args, workers=4))
        self.desmiled = np.array([result[s] for s in range(current.shape[0])])

    def _refine_results(self) -> float:
        """
        Refine the hard flat and apply it to create a new de-smiled image.

        Returns the STD of the difference of two consecutive hard flats.
        """
        std = self._refine_hard_flat()
        self._correct_original()
        return std

    def _artificial_flat(self):
        """Iterate the hard-flat refinement ``config.sensor_flat_iterations`` times."""
        log.info('iterating hard flat')
        # Keep the pre-refinement result for the before/after comparison plot.
        first_order = self.desmiled
        stds = [self._refine_results() for _ in range(self._config.sensor_flat_iterations)]
        self._apply_c2c_pattern()
        self._plot_refined_results(stds, first_order)

    def _apply_c2c_pattern(self):
        """Divide the c2c pattern, plus a re-estimated residual c2c pattern, out of ``pre_flat``."""
        if self.c2c_pattern is None:
            return

        self.pre_flat = self.pre_flat / self.c2c_pattern
        img = Flat(self.pre_flat).correct(self._input)
        if self._config.roi is not None:
            img = np.array([img[s][self._config.roi] for s in range(img.shape[0])])
        c2c_pattern = extract_c2c_comp(img, (self._orig.shape[1], self._orig.shape[2]))
        self.pre_flat = self.pre_flat / c2c_pattern

    def _plot_refined_results(self, stds: list, first_order: np.ndarray):
        """Plot the refinement convergence, the new hard flat and before/after cuts into the report."""
        if self._report is None:
            return

        plt_std_of_consecutive_hard_flats(stds, self._report)
        m = self.pre_flat.mean()
        s = 3 * self.pre_flat.std()
        plt_state_imgs(self.pre_flat, title='New iterated hard flat (3-sigma)', pdf=self._report, clim=[m - s, m + s])
        plt_state_imgs(self.desmiled, title='Desmiled corrected input', pdf=self._report)
        if self._config.roi is None:
            cut = (0, first_order.shape[1] // 2)
        else:
            rows = self._config.roi[0]
            # first_order/desmiled hold full frames, so the ROI's start row must be added to reach
            # the centre of the ROI (this matches _append_offset_plots_to_report).
            cut = (0, rows.start + (rows.stop - rows.start) // 2, self._config.roi[1])
        self._plt_cuts((first_order[cut], "Old"), (self.desmiled[cut], "New"))
        plt_spatial_comparison(self._input, self.desmiled, pdf=self._report, roi=self._config.roi)

    def _plt_cuts(self, a: tuple, b: tuple):
        """
        Plot two labelled 1-D cuts, each given as ``(data, label)``.

        The top panel shows the full cuts; the bottom panel shows their central half.
        """
        fig, ax = plt.subplots(nrows=2, ncols=1, figsize=A4_LANDSCAPE)
        fig.suptitle('State 0 spectra before and after correction')
        ax[0].plot(a[0], label=a[1])
        ax[0].plot(b[0], label=b[1])
        n = len(a[0]) // 2
        ax[1].plot(range(n - n // 4, n + n // 4), a[0][n - n // 4:n + n // 4])
        ax[1].plot(range(n - n // 4, n + n // 4), b[0][n - n // 4:n + n // 4])
        fig.legend()
        fig.tight_layout()
        self._report.savefig()
        plt.close()

    def _remove_lines(self):
        """Remove the vertical (spectral) lines from the de-smiled data to obtain the gain table."""
        log.info('Removing vertical lines')
        if self._config.roi is None:
            self._remove_lines_full()
        else:
            self._remove_lines_roi()
        self._plt_line_removal()

    def _remove_lines_roi(self):
        """Build the gain table from the ROI only; pixels outside the ROI get a gain of 1."""
        roi = (slice(None, None), self._config.roi[0], self._config.roi[1])
        lr = LineRemover(self.desmiled[roi]).run()
        temp = np.array([lr.result[s] / lr.result[s].mean() for s in range(lr.result.shape[0])])
        self.gain_table = np.ones(self._input.shape)
        # ``roi`` already spans every state, so a single assignment fills all of them.
        self.gain_table[roi] = temp

    def _remove_lines_full(self):
        """Build the gain table from the full frames, normalising each state to mean 1."""
        lr = LineRemover(self.desmiled).run()
        self.gain_table = np.array([lr.result[s] / lr.result[s].mean() for s in range(lr.result.shape[0])])

    def _plt_line_removal(self):
        """Plot the gain table after line removal, clipped to mean +/- 3 sigma."""
        if self._report is None:
            return

        m = self.gain_table.mean()
        s = 3 * self.gain_table.std()
        plt_state_imgs(self.gain_table, title='Gain table after removing vertical lines (3-Sigma)',
                       pdf=self._report, clim=[m - s, m + s])

    def _interpolate_line_residuals(self):
        """Smooth line residuals in the gain table when ``config.smile.smooth`` is set."""
        if not self._config.smile.smooth:
            return

        log.info('Smoothing residuals')
        if self._config.roi is None:
            self._interpolate_full()
        else:
            self._interpolate_roi()
        self._plt_smoothing_result()

    def _interpolate_full(self):
        """Apply ResidualsRemover and then GaussianBlur to every full gain-table state."""
        for s in range(self.gain_table.shape[0]):
            self.gain_table[s] = ResidualsRemover(self.gain_table[s]).run().img
            self.gain_table[s] = GaussianBlur(self.gain_table[s]).run().gain

    def _interpolate_roi(self):
        """Apply ResidualsRemover and then GaussianBlur to the ROI of every gain-table state."""
        roi = self._config.roi
        for s in range(self.gain_table.shape[0]):
            self.gain_table[s][roi] = ResidualsRemover(self.gain_table[s][roi]).run().img
            self.gain_table[s][roi] = GaussianBlur(self.gain_table[s][roi]).run().gain

    def _cleanup(self):
        """Close the PDF report (if open) and drop the reference so it cannot be written again."""
        if self._report is not None:
            self._report.close()
            self._report = None

    def _plt_image_results(self):
        """Plot state-0 before/after de-smiling and the state-averaged images into the report."""
        if self._report is None:
            return
        plt_adjustment_comparison(self._orig[0], self.desmiled[0], self._report, roi=self._config.roi)
        inp = np.average(self._input, axis=0)
        plt_img(inp, title='Averaged input image', pdf=self._report, roi=self._config.roi)
        plt_img(np.average(self.desmiled, axis=0), title='Averaged input image after de-smiling',
                pdf=self._report, clim=[inp.min(), inp.max()])

    def _plt_smoothing_result(self):
        """Plot the smoothed gain table, clipped to mean +/- 3 sigma."""
        if self._report is None:
            return

        m = self.gain_table.mean()
        s = 3 * self.gain_table.std()
        plt_state_imgs(self.gain_table, title='Gain table after smoothing (3-Sigma)',
                       pdf=self._report, clim=[m - s, m + s])

    def _plt_applied_flat(self):
        """Plot the gain-corrected data and the normalised differences of each state to state 0."""
        if self._report is None:
            return

        # Divide only where the gain is non-zero; elsewhere the de-smiled value is kept (``out``).
        corrected = np.true_divide(self.desmiled.astype('float32'), self.gain_table.astype('float32'),
                                   out=self.desmiled.astype('float32'), where=self.gain_table != 0,
                                   dtype='float64')
        plt_state_imgs(corrected, title='Flat fielded input image', pdf=self._report)
        plt_spatial_comparison(self._input, corrected, pdf=self._report, roi=self._config.roi)
        # Reuse the buffer: state i becomes (normalised state i) - (normalised state 0).
        corrected[0] = self.desmiled[0] / np.mean(self.desmiled[0])
        for i in range(1, corrected.shape[0]):
            corrected[i] = (self.desmiled[i] / self.desmiled[i].mean()) - corrected[0]
        corrected[0] = np.zeros(corrected[0].shape)
        plt_state_imgs(corrected,
                       title='Mod state 0 subtracted from the other de-smiled images (Normalized)',
                       pdf=self._report)
        self._plt_diff_cuts(corrected)

    def _plt_diff_cuts(self, img: np.ndarray):
        """Plot the central row of every difference image (states 1..n) on a common axis."""
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=A4_LANDSCAPE)
        for s in range(1, img.shape[0]):
            ax.plot(img[s, img.shape[1] // 2], label=f'<#{s}> - <#0>')
        ax.grid(True)
        fig.suptitle("Spectral cuts of delta images")
        ax.legend()
        ax.set_ylim([-0.015, 0.015])
        ax.set_xlim([0, img.shape[2]])
        ax.set_xlabel(r'$\lambda$ [px]')
        fig.tight_layout()
        self._report.savefig()
        plt.close()

    def _plt_input(self):
        """Plot the raw input states into the report (if any)."""
        if self._report is None:
            return

        plt_state_imgs(self._orig, title='Input data', pdf=self._report)