Coverage for src/spectroflat/smile/smile_fit.py: 72%
125 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
1import copy
2import warnings
3from dataclasses import dataclass
5import numpy as np
6from scipy.interpolate import CubicSpline
7from scipy.optimize import brute
8from qollib.ui.progress import msg
10from ..base.config import SmileConfig
11from ..base.logging import Logging
12from ..utils.line_detection import find_line_cores
13from ..utils.processing import MP
# Module-level logger obtained from the project's Logging helper; used by SmileFit.run
# to report which state is being processed.
log = Logging.get_logger()
@dataclass
class FitRef:
    """Reference profile that the rows of a frame are fitted against.

    ``values`` is the reference intensity profile. ``cores`` and ``peaks`` are
    the line positions detected in it; both stay ``None`` until assigned.
    """
    # Reference intensity profile (one value per column).
    values: np.ndarray
    # Line-core positions detected in ``values``; None until assigned.
    cores: 'np.ndarray | None' = None
    # Peak positions belonging to ``cores``; None until assigned.
    peaks: 'np.ndarray | list | None' = None
class SmileFit:
    """Fit the smile (per-row shifts) of a single frame or of a stack of frames.

    A frame with at most two dimensions is processed as one state (state 0).
    For higher-dimensional input, each entry along axis 0 is processed as its
    own state. After :meth:`run`, ``shift_map`` and ``chi2_map`` hold one
    entry per state, stacked with ``np.array``.
    """

    def __init__(self, img: np.ndarray, config: SmileConfig):
        self._img = img
        self._conf = config
        # Starts at -1 so the first _get_state() call advances to state 0.
        self._state = -1
        self._current = None
        self.shift_map = []
        self.chi2_map = []

    def run(self):
        """Fit every state, then convert the collected results to arrays."""
        while self._get_state():
            log.info('Processing state %s', self._state)
            self._process_state()
        self._post_process_results()

    def _get_state(self) -> bool:
        """Advance to the next state and load its frame into ``_current``.

        :return: True while there is a frame left to process, False otherwise.
        """
        self._state += 1
        if self._img.ndim <= 2:
            # A single frame: only state 0 exists.
            self._current = self._img
            return self._state == 0
        if self._state < self._img.shape[0]:
            self._current = self._img[self._state]
            return True
        return False

    def _process_state(self):
        """Fit the current frame and store its shifts and errors."""
        sif = _SmileImgFit(self._current, self._conf).run()
        self.shift_map.append(sif.shifts)
        self.chi2_map.append(sif.errors)

    def _post_process_results(self):
        """Stack the per-state results into numpy arrays."""
        self.shift_map = np.array(self.shift_map)
        self.chi2_map = np.array(self.chi2_map)
class _SmileImgFit:
    """Fit per-row shifts of a single 2D frame against a reference profile.

    After :meth:`run`, ``shifts`` holds one shift array per row and ``errors``
    holds one error entry per row, in row order.
    """

    # Half-height (in rows) of the band around the frame centre that is
    # averaged into the reference profile.
    _REF_HALF_WIDTH = 3

    def __init__(self, img: np.ndarray, config: SmileConfig):
        self._img = img
        self._conf = config
        self.shifts = []
        self.errors = []
        self._ref = None

    def run(self):
        """Build the reference, fit every row, optionally smooth; return self."""
        self._select_reference()
        self._fit_rows()
        self._smooth()
        return self

    def _fit_rows(self):
        """Fit each row via ``MP.simultaneous`` and collect results in row order."""
        rows = range(self._img.shape[0])
        jobs = [(r, self._img[r], self._ref, self._conf) for r in rows]
        # _fit_row returns (row_index, fit) pairs, which may arrive in any order.
        res = dict(MP.simultaneous(_fit_row, jobs))
        msg(flush=True)
        for r in rows:
            self.shifts.append(res[r].shifts)
            self.errors.append(res[r].error)

    def _select_reference(self):
        """Build the reference profile from a band of rows around the frame centre.

        Lines whose core could not be determined (``None``) are dropped, and the
        reference values are normalized by their mean.
        """
        center = self._img.shape[0] // 2
        # Clamp at 0: a negative start would wrap around to the end of the frame
        # (e.g. 4 rows -> img[-1:5] -> only the last row) instead of the centre band.
        start = max(center - self._REF_HALF_WIDTH, 0)
        stop = center + self._REF_HALF_WIDTH
        self._ref = FitRef(values=np.average(self._img[start:stop], axis=0))
        peaks, cores = find_line_cores(self._ref.values, self._conf)
        nix = [i for i, v in enumerate(cores) if v is None]
        if nix:
            peaks = np.delete(peaks, nix).astype(float)
            cores = np.delete(cores, nix).astype(float)
        self._ref.values = self._ref.values / np.mean(self._ref.values)
        self._ref.cores = cores
        self._ref.peaks = peaks

    def _smooth(self):
        """Replace each column of ``shifts`` by a polynomial fit along the rows.

        Skipped when ``smile_deg`` is below 1. Otherwise ``shifts`` becomes a
        2D array (rows x columns).
        """
        if self._conf.smile_deg < 1:
            return
        self.shifts = np.array(self.shifts)
        yes = np.arange(self._img.shape[0])
        for col in range(self._img.shape[1]):
            poly = np.polynomial.Polynomial.fit(yes, self.shifts[:, col], deg=self._conf.smile_deg)
            self.shifts[:, col] = poly(yes)
class _RowFit:
    """Fit the dispersion of one detector row against a reference profile.

    Each integer polynomial degree in
    ``[min_dispersion_deg, max_dispersion_deg]`` is tried for the mapping from
    the row's line positions to the reference line cores. The degree with the
    smallest summed error is kept. After :meth:`run`, ``shifts`` holds the
    per-pixel shift and ``error`` the per-pixel error of that fit.
    """

    def __init__(self, row: np.ndarray, ref: FitRef, config: SmileConfig):
        self._row = row
        # Shallow copy of the shared reference (contained arrays are not copied).
        self._ref = copy.copy(ref)
        self._conf = config
        self._xes = np.arange(len(self._row))
        self._lines = []
        self.shifts = []
        self.error = []

    def run(self):
        """Detect the row's lines, then search for the best fit; return self."""
        self._find_lines()
        self._find_fit()
        return self

    def _find_lines(self):
        """Locate the line cores in this row, seeded with the reference peaks."""
        _, cores = find_line_cores(self._row, self._conf, self._ref.peaks)
        self._lines = np.array(cores)

    def _find_fit(self):
        """Brute-force the polynomial degree, then refit with the best one."""
        # Inclusive integer grid min_dispersion_deg .. max_dispersion_deg.
        degrees = (slice(self._conf.min_dispersion_deg, self._conf.max_dispersion_deg + 1, 1),)
        with warnings.catch_warnings():
            # We do a brute force approach for the best deg. Thus, we can safely ignore RankWarnings.
            warnings.filterwarnings('ignore', message='.*The fit may.*')
            # finish=None: the objective rounds to an integer degree, so it is piecewise
            # constant. The default fmin polish gains nothing and can step outside the
            # configured range. Without it, brute returns the best grid point.
            best = brute(self._chi2_error, degrees, finish=None)
            # Re-evaluate so shifts/error reflect the chosen degree.
            self._chi2_error(best)

    def _chi2_error(self, deg) -> float:
        """Fit with polynomial degree ``deg`` and return the summed per-pixel error."""
        # brute may pass a scalar or a 1-element array; take the single value.
        self._fit_shifts(int(np.round(np.ravel(deg)[0])))
        self._compute_chi2_error()
        return float(np.sum(self.error))

    def _fit_shifts(self, deg: int):
        """Map pixel positions onto the reference and store per-pixel shifts."""
        uspl = np.polynomial.Polynomial.fit(self._lines, self._ref.cores, deg)
        # Enforce monotony here by stepping through the values and
        # using always the maximum seen so far.
        dispersion = np.maximum.accumulate(uspl(self._xes))
        self.shifts = dispersion - self._xes

    def _compute_chi2_error(self):
        """Resample the row onto the shifted grid and compare it to the reference.

        Sets ``self.error`` to ``(current - ref) ** 2 / ref ** 2`` per pixel,
        where ``current`` is the resampled, mean-normalized row. If this fails
        with ValueError, every pixel gets ``inf``. The degree then never beats
        a finite fit, and ``error`` keeps the row's shape.
        """
        try:
            cs = CubicSpline(self.shifts + self._xes, self._row)
            current = cs(self._xes)
            current = current / np.mean(current)
            self.error = (current - self._ref.values) ** 2 / self._ref.values ** 2
        except ValueError:
            # CubicSpline needs strictly increasing x. The monotony enforcement in
            # _fit_shifts only guarantees non-decreasing positions, so plateaus can occur.
            self.error = np.full(self._xes.shape, np.inf)
def _fit_row(args: tuple) -> tuple:
    """Worker entry point: fit a single row and return it keyed by its index.

    ``args`` is ``(row_index, row, ref, config)``. The result is
    ``(row_index, fitted _RowFit)``, so callers can collect results in a dict.
    """
    index, row, ref, config = args[0], args[1], args[2], args[3]
    fitted = _RowFit(row, ref, config)
    fitted = fitted.run()
    msg(f'row {index}')
    return index, fitted