Coverage for src/spectroflat/smile/smoothing.py: 91%

96 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-03-28 07:59 +0000

1from __future__ import annotations 

2 

3import warnings 

4 

5import numpy as np 

6from scipy import signal as sig 

7from scipy.ndimage import gaussian_filter 

8from skimage.filters import gaussian 

9 

10from ..utils import Collections 

11 

12 

class LineRemover:
    """
    Strips the vertical features (i.e. absorption/emission lines) out of a flat field cube.

    Every state image of the cube is divided by its own mean spectrum, i.e. the
    mean taken over all rows for each column. Each column is thereby scaled by a
    single factor, so the variation along the column is kept.

    The input datacube must be de-smiled.
    """

    def __init__(self, cube: np.array):
        #: The original datacube
        self.cube = cube
        #: The resulting Flat Field
        self.result = []

    def run(self) -> LineRemover:
        """
        Processes the states of the cube (first axis) one after the other and
        collects the line-free images in `result`.

        Returns this instance to allow call chaining.
        """
        n_states = self.cube.shape[0]
        flattened = [self.__remove_lines(idx) for idx in range(n_states)]
        self.result = np.array(flattened)
        return self

    def __remove_lines(self, state) -> np.array:
        # Divide the selected state image by its column-wise mean.
        frame = self.cube[state]
        # keepdims leaves a leading axis of length one, so the division
        # broadcasts the mean spectrum over all rows of the frame.
        column_means = np.mean(frame, axis=0, keepdims=True)
        return frame / column_means

37 

38 

class ResidualsRemover:
    """
    Helper to remove vertical 1px wide line residuals by interpolating over the
    left and right border of the peak.

    The residuals are detected by finding peaks (maxima and minima) in the
    horizontal profile of the image, i.e. its mean over the rows. The detection
    is done twice: once with the profile of all rows and once with the profile
    of the central region of the image only.

    The processed image is available as `img` after `run()` was called.
    """

    # Number of columns on each side of a peak that are averaged to compute
    # the replacement value.
    _WINDOW = 4

    def __init__(self, img: np.ndarray, peak_threshold: float = 0.5):
        #: The image to clean. Replaced by the processed image during `run()`
        self.img = img
        # Minimal prominence a peak of the profile needs to count as residual
        self._threshold = peak_threshold
        # Column indices of the most recently detected residuals
        self._peaks = None

    def run(self) -> ResidualsRemover:
        """
        Applies the complete chain: outlier removal, smoothing of the globally and
        the locally detected residuals, a second outlier removal and finally a
        re-normalization to a mean of one.

        Returns this instance to allow call chaining.
        """
        self._smooth_outliers()
        self._smooth_global_vertical_residuals()
        self._smooth_local_vertical_residuals()
        self._smooth_outliers()
        self._re_normalize()
        return self

    def _smooth_outliers(self):
        # NOTE(review): presumably replaces values deviating more than 2.8 sigma;
        # the helper lives in ..utils - confirm there.
        self.img = Collections.remove_sigma_outliers(self.img, s=2.8)

    def _smooth_local_vertical_residuals(self):
        self._find_local_vertical_reseduals()
        self._smooth_vertical_residuals()

    def _smooth_global_vertical_residuals(self):
        self._find_global_vertical_reseduals()
        self._smooth_vertical_residuals()

    def _find_local_vertical_reseduals(self):
        rows = self.img.shape[0]
        quarter = rows // 4
        # An explicit end index is used on purpose: with fewer than four rows
        # `quarter` is 0 and a `[quarter:-quarter]` slice would be empty, which
        # turns the whole profile into NaN. Now the full image is used instead.
        central = self.img[quarter:rows - quarter, :]
        self._peaks = self._detect_peaks(np.mean(central, axis=0))

    def _find_global_vertical_reseduals(self):
        self._peaks = self._detect_peaks(np.mean(self.img, axis=0))

    def _detect_peaks(self, profile: np.ndarray) -> np.ndarray:
        """Return the indices of all maxima, followed by all minima, of `profile` that reach the prominence threshold."""
        pos, _ = sig.find_peaks(profile, prominence=self._threshold)
        # Minima of the profile are the maxima of the negated profile
        neg, _ = sig.find_peaks(-profile, prominence=self._threshold)
        return np.concatenate([pos, neg])

    def _smooth_vertical_residuals(self):
        """
        Overwrites the peak column and its two direct neighbours with the average
        of the mean of the columns left and the mean of the columns right of the peak.

        The image is modified in place and the peaks are processed one after the
        other, so an earlier replacement is visible to the windows of later peaks.
        """
        for peak in self._peaks:
            # Clamp the start of the left window at the image border. A negative
            # start index would wrap around to the right edge and select an empty
            # slice, whose mean is NaN. `find_peaks` only reports interior
            # samples, so after clamping both windows hold at least one column.
            first = max(peak - self._WINDOW, 0)
            left = np.mean(self.img[:, first:peak], axis=1) / 2
            right = np.mean(self.img[:, peak + 1:peak + 1 + self._WINDOW], axis=1) / 2
            # One value per row, broadcast over the three columns to replace
            self.img[:, peak - 1:peak + 2] = (left + right)[:, np.newaxis]

    def _re_normalize(self):
        # Scale the image to a global mean of one
        self.img = self.img / np.mean(self.img)

93 

94 

class GlobalSmudger:
    """
    ## GlobalSmudger

    This class is to be applied after the absorption lines have been removed from the
    spectral image. As the main residuals are vertical, we first fit a polynomial in
    vertical direction (along every column) to get the remaining global gradients of
    the image. Then, we do the same in horizontal direction (along every row).
    Finally, a gaussian filter is applied to create a smooth gain table.

    Polynomial degree and the sigma value to ignore outliers are configurable.
    The resulting gain table is available as `gain` after `run()` was called.
    """

    # `np.RankWarning` was removed from the top-level namespace with NumPy 2.0 and
    # is provided by `np.exceptions` there. Older releases only have the top-level name.
    try:
        _RANK_WARNING = np.exceptions.RankWarning
    except AttributeError:
        _RANK_WARNING = np.RankWarning

    def __init__(self, img: np.ndarray, deg: int = 11, sigma_mask: float = 0.9):
        # The unmodified input image
        self._orig = img
        # Degree of the fitted polynomials
        self._deg = deg
        # Sigma value handed to the outlier removal
        self._sigma_mask = sigma_mask
        #: The resulting gain table (uninitialized memory until `run()` was called)
        self.gain = np.empty(img.shape)

    def run(self) -> GlobalSmudger:
        """
        Removes the outliers, fits all columns, fits all rows and blurs the result.

        Returns this instance to allow call chaining.
        """
        self._setup()
        self._fit_cols()
        self._fit_rows()
        self._blur()
        return self

    def _setup(self):
        # The index ranges double as x values of the polynomial fits
        self._cols = range(self._orig.shape[1])
        self._rows = range(self._orig.shape[0])
        # NOTE(review): presumably replaces values deviating more than
        # `sigma_mask` sigma; the helper lives in ..utils - confirm there.
        self.gain = Collections.remove_sigma_outliers(self._orig, self._sigma_mask)

    def _fit_rows(self):
        for r in self._rows:
            self._fit_row(r, self._cols)

    def _fit_cols(self):
        # Work on the transposed table so that columns can be fitted like rows
        self.gain = self.gain.T
        for c in self._cols:
            self._fit_row(c, self._rows)
        self.gain = self.gain.T

    def _fit_row(self, r: int, xes: range):
        """Replace row `r` of the gain table by the polynomial fitted to it at the positions `xes`."""
        with warnings.catch_warnings():
            # The fit may be poorly conditioned; the resulting rank warning is silenced on purpose
            warnings.simplefilter('ignore', self._RANK_WARNING)
            poly = np.poly1d(np.polyfit(xes, self.gain[r], self._deg))
        self.gain[r] = poly(xes)

    def _blur(self):
        # skimage gaussian filter with its default parameters
        self.gain = gaussian(self.gain)

144 

145 

class GaussianBlur:
    """
    Smooths an image with `scipy.ndimage.gaussian_filter`.

    `kernel` is handed to the filter as its `sigma` and `truncate` as its
    `truncate` argument. The blurred image is available as `gain` after
    `run()` was called.
    """

    def __init__(self, img: np.array, kernel: int = 250, truncate: float = 1.5):
        self._orig, self._kernel, self._truncate = img, kernel, truncate
        #: The blurred image (uninitialized memory until `run()` was called)
        self.gain = np.empty(img.shape)

    def run(self):
        """Blur the original image, store the outcome in `gain` and return this instance."""
        filter_args = {'sigma': self._kernel, 'truncate': self._truncate}
        self.gain = gaussian_filter(self._orig, **filter_args)
        return self