Coverage for src/spectroflat/smile/smile_fit.py: 72%

125 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-03-28 07:59 +0000

1import copy 

2import warnings 

3from dataclasses import dataclass 

4 

5import numpy as np 

6from scipy.interpolate import CubicSpline 

7from scipy.optimize import brute 

8from qollib.ui.progress import msg 

9 

10from ..base.config import SmileConfig 

11from ..base.logging import Logging 

12from ..utils.line_detection import find_line_cores 

13from ..utils.processing import MP 

14 

# Module-wide logger from the project's Logging helper; SmileFit.run uses it to report the state being processed.
log = Logging.get_logger()

16 

17 

@dataclass
class FitRef:
    """
    Reference profile that every row of a frame is fitted against.

    Only ``values`` is required at construction; ``cores`` and ``peaks`` stay
    ``None`` until the reference lines have been located
    (see ``_SmileImgFit._select_reference``).
    """

    # Reference intensity profile (normalised to unit mean once the reference is selected).
    values: np.ndarray
    # Positions of the reference line cores; None until populated.
    cores: np.ndarray = None
    # Peak positions reported by find_line_cores for the reference; None until populated.
    peaks: list = None

23 

24 

class SmileFit:
    """
    Entry point of the smile fit for a single frame or a stack of frames.

    A 2D ``img`` is processed as exactly one state; otherwise every entry along
    the first axis is processed as its own state. The shifts and chi^2 errors of
    each state are appended to ``shift_map`` / ``chi2_map``, and ``run`` turns
    both lists into numpy arrays once all states are done.
    """

    def __init__(self, img: np.array, config: SmileConfig):
        self._conf = config
        self._img = img
        self.shift_map = []
        self.chi2_map = []
        # -1 means "nothing processed yet"; _get_state pre-increments before use.
        self._state = -1
        self._current = None

    def run(self):
        """Process all states in order, then stack the collected results."""
        while self._get_state():
            log.info('Processing state %s', self._state)
            self._process_state()
        self._post_process_results()

    def _get_state(self) -> bool:
        """Advance to the next state and select its frame; return False when none is left."""
        self._state += 1
        single_frame = len(self._img.shape) <= 2
        if single_frame:
            # A lone 2D frame yields exactly one state: state 0.
            self._current = self._img
            return self._state == 0
        has_more = self._state < self._img.shape[0]
        if has_more:
            self._current = self._img[self._state]
        return has_more

    def _process_state(self):
        """Fit the frame selected by ``_get_state`` and record its results."""
        frame_fit = _SmileImgFit(self._current, self._conf)
        frame_fit.run()
        self.shift_map.append(frame_fit.shifts)
        self.chi2_map.append(frame_fit.errors)

    def _post_process_results(self):
        """Replace the per-state result lists by numpy arrays."""
        self.shift_map = np.array(self.shift_map)
        self.chi2_map = np.array(self.chi2_map)

60 

61 

62class _SmileImgFit: 

63 

64 def __init__(self, img: np.array, config: SmileConfig): 

65 self._img = img 

66 self._conf = config 

67 self.shifts = [] 

68 self.errors = [] 

69 self._ref = None 

70 

71 def run(self): 

72 self._select_reference() 

73 self._fit_rows() 

74 self._smooth() 

75 return self 

76 

77 def _fit_rows(self): 

78 rows = range(self._img.shape[0]) 

79 res = dict(MP.simultaneous(_fit_row, [(r, self._img[r], self._ref, self._conf) for r in rows])) 

80 msg(flush=True) 

81 for r in rows: 

82 self.shifts.append(res[r].shifts) 

83 self.errors.append(res[r].error) 

84 

85 def _select_reference(self): 

86 center = self._img.shape[0] // 2 

87 self._ref = FitRef(values=np.average(self._img[center - 3: center + 3], axis=0)) 

88 peaks, cores = find_line_cores(self._ref.values, self._conf) 

89 nix = [i for i, v in enumerate(cores) if v is None] 

90 if nix: 

91 peaks = np.delete(peaks, nix).astype(float) 

92 cores = np.delete(cores, nix).astype(float) 

93 self._ref.values = self._ref.values / np.mean(self._ref.values) 

94 self._ref.cores = cores 

95 self._ref.peaks = peaks 

96 

97 def _smooth(self): 

98 if self._conf.smile_deg < 1: 

99 return 

100 self.shifts = np.array(self.shifts) 

101 yes = np.arange(self._img.shape[0]) 

102 for col in range(self._img.shape[1]): 

103 poly = np.polynomial.Polynomial.fit(yes, self.shifts[:, col], deg=self._conf.smile_deg) 

104 self.shifts[:, col] = poly(yes) 

105 

106 

107class _RowFit: 

108 

109 def __init__(self, row: np.array, ref: FitRef, config: SmileConfig): 

110 self._row = row 

111 self._ref = copy.copy(ref) 

112 self._conf = config 

113 self._xes = np.arange(len(self._row)) 

114 self._lines = [] 

115 self.shifts = [] 

116 self.error = [] 

117 

118 def run(self): 

119 self._find_lines() 

120 self._find_fit() 

121 return self 

122 

123 def _find_lines(self): 

124 peaks, cores = find_line_cores(self._row, self._conf, self._ref.peaks) 

125 self._lines = np.array(cores) 

126 

127 def _find_fit(self): 

128 with warnings.catch_warnings(): 

129 # We do a brute force approach for the best deg. Thus, we can safely ignore RankWarnings. 

130 warnings.filterwarnings('ignore', message='.*The fit may.*') 

131 res = brute(self._chi2_error, (slice(self._conf.min_dispersion_deg, self._conf.max_dispersion_deg + 1, 1),)) 

132 self._chi2_error(res) 

133 

134 def _chi2_error(self, deg: tuple) -> float: 

135 self._fit_shifts(int(np.round(deg))) 

136 self._compute_chi2_error() 

137 return float(np.sum(self.error)) 

138 

139 def _fit_shifts(self, deg: int): 

140 uspl = np.polynomial.Polynomial.fit(self._lines, self._ref.cores, deg) 

141 # Enforce monotony here by stepping through the values and 

142 # using always the maximum seen so far. 

143 dispersion = np.maximum.accumulate(uspl(self._xes)) 

144 self.shifts = dispersion - self._xes 

145 

146 def _compute_chi2_error(self): 

147 try: 

148 cs = CubicSpline(self.shifts + self._xes, self._row) 

149 current = cs(self._xes) 

150 current = current / np.mean(current) 

151 self.error = (current - self._ref.values) ** 2 / self._ref.values ** 2 

152 except ValueError: 

153 self.error = np.infty 

154 

155 

def _fit_row(args: tuple) -> tuple:
    """
    Worker entry point for ``MP.simultaneous``: fit one image row.

    :param args: ``(row_index, row_values, reference, config)``
    :return: ``(row_index, fitted _RowFit)``, so the caller can key results by row.
    """
    row_index, row_values, reference, config = args[:4]
    row_fit = _RowFit(row_values, reference, config).run()
    msg(f'row {row_index}')
    return row_index, row_fit