Coverage for src/spectroflat/smile/smoothing.py: 91%
96 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
1from __future__ import annotations
3import warnings
5import numpy as np
6from scipy import signal as sig
7from scipy.ndimage import gaussian_filter
8from skimage.filters import gaussian
10from ..utils import Collections
class LineRemover:
    """
    Removes the vertical features (i.e. absorption/emission lines) from the flat field image.
    The input datacube must be de-smiled.

    Every state image of the cube is divided by its own mean spectrum, i.e. by the
    average over all rows of that image. Features that are constant along the vertical
    axis cancel out, while the vertical gradient is maintained.
    """

    def __init__(self, cube: np.ndarray):
        #: The original datacube (indexed by state along the first axis)
        self.cube = cube
        #: The resulting Flat Field. An empty list until `run` has been called,
        #: afterwards an array with one processed image per state.
        self.result = []

    def run(self) -> LineRemover:
        """
        Iterates over all mod states in the image cube and removes the vertical features
        while maintaining the vertical gradient.

        :return: this instance, the processed cube is available as `result`
        """
        self.result = np.array([self.__remove_lines(s) for s in range(self.cube.shape[0])])
        return self

    def __remove_lines(self, state) -> np.ndarray:
        """
        Divide the image of the given state by its mean spectrum.

        :param state: index of the state image within the cube
        :return: the normalized image as a new array, the cube itself is not modified
        """
        img = self.cube[state]
        # keepdims leaves the mean spectrum with shape (1, columns), so broadcasting
        # divides every row by it without building a full-size repeated copy.
        return img / np.mean(img, axis=0, keepdims=True)
class ResidualsRemover:
    """
    Helper to remove vertical 1px wide line residuals by interpolating over the
    left and right border of the peak.

    The residuals are detected by finding peaks (positive and negative) in the
    column profile of the image, i.e. the mean over the rows. This is done twice:
    once using all rows (global) and once using only the central half of the rows (local).
    """

    def __init__(self, img: np.ndarray, peak_threshold: float = 0.5):
        #: The image to clean. It is modified and replaced while `run` is executed.
        self.img = img
        # Minimal prominence a peak in the column profile needs to be treated as residual
        self._threshold = peak_threshold
        # Column indices of the residuals found by the latest detection step
        self._peaks = None

    def run(self) -> ResidualsRemover:
        """
        Execute the full cleaning sequence: outlier smoothing, global and local
        residual removal, a second outlier smoothing and a final re-normalization.

        :return: this instance, the cleaned image is available as `img`
        """
        self._smooth_outliers()
        self._smooth_global_vertical_residuals()
        self._smooth_local_vertical_residuals()
        self._smooth_outliers()
        self._re_normalize()
        return self

    def _smooth_outliers(self):
        # NOTE(review): presumably replaces values deviating more than `s` sigma;
        # confirm against `Collections.remove_sigma_outliers`.
        self.img = Collections.remove_sigma_outliers(self.img, s=2.8)

    def _smooth_local_vertical_residuals(self):
        self._find_local_vertical_reseduals()
        self._smooth_vertical_residuals()

    def _smooth_global_vertical_residuals(self):
        self._find_global_vertical_reseduals()
        self._smooth_vertical_residuals()

    def _find_local_vertical_reseduals(self):
        """Detect residuals in the column profile of the central half of the rows."""
        rows = self.img.shape[0]
        quarter = rows // 4
        # `rows - quarter` instead of `-quarter`: with fewer than 4 rows quarter is 0
        # and `[0:-0]` would select no rows at all, resulting in an all-NaN profile.
        self._peaks = self._find_peaks(np.mean(self.img[quarter:rows - quarter, :], axis=0))

    def _find_global_vertical_reseduals(self):
        """Detect residuals in the column profile of the full image."""
        self._peaks = self._find_peaks(np.mean(self.img, axis=0))

    def _find_peaks(self, one_dim: np.ndarray) -> np.ndarray:
        """
        Find the positive and the negative peaks of the given profile.

        :param one_dim: 1D column profile of the image
        :return: indices of all maxima, followed by the indices of all minima
        """
        pos, _ = sig.find_peaks(one_dim, prominence=self._threshold)
        neg, _ = sig.find_peaks(-one_dim, prominence=self._threshold)
        return np.concatenate([pos, neg])

    def _smooth_vertical_residuals(self):
        """
        Replace the three columns centered on each detected peak by the average of
        the row-wise means of (up to) four columns to the left and to the right of the peak.
        """
        for peak in self._peaks:
            # Clamp at the left border: a negative start would wrap around, select an
            # empty window and write NaN into the image. The right window needs no
            # clamping, slices are cut at the end of the array automatically.
            # `find_peaks` never reports the first or last sample, hence there is
            # always at least one column on either side of the peak.
            start = max(peak - 4, 0)
            left = np.mean(self.img[:, start:peak], axis=1) / 2
            right = np.mean(self.img[:, peak + 1:peak + 5], axis=1) / 2
            # One value per row, broadcast over the peak and its direct neighbours
            self.img[:, peak - 1:peak + 2] = (left + right)[:, np.newaxis]

    def _re_normalize(self):
        # Scale the image such that its overall mean is 1
        self.img = self.img / np.mean(self.img)
class GlobalSmudger:
    """
    ## GlobalSmudger

    This class is to be applied after the absorption lines have been removed from the
    spectral image. To get the remaining global gradients of the image, we first fit a
    polynomial to every column (i.e. along the vertical direction).
    Then, we do the same for every row (i.e. along the horizontal direction).
    Finally, a gaussian filter is applied to create a smooth gain table.

    Polynomial degree and the sigma value to ignore outliers are configurable.
    """

    def __init__(self, img: np.ndarray, deg: int = 11, sigma_mask: float = 0.9):
        # The unmodified input image
        self._orig = img
        # Degree of the polynomials fitted to the columns and rows
        self._deg = deg
        # Sigma value handed to the outlier removal in `_setup`
        self._sigma_mask = sigma_mask
        #: The resulting gain table. Uninitialized memory until `run` has been called.
        self.gain = np.empty(img.shape)

    def run(self) -> GlobalSmudger:
        """
        Compute the gain table: outlier removal, column fits, row fits and final blur.

        :return: this instance, the result is available as `gain`
        """
        self._setup()
        self._fit_cols()
        self._fit_rows()
        self._blur()
        return self

    def _setup(self):
        self._cols = range(self._orig.shape[1])
        self._rows = range(self._orig.shape[0])
        # NOTE(review): presumably replaces values deviating more than `sigma_mask`
        # sigma; confirm against `Collections.remove_sigma_outliers`.
        self.gain = Collections.remove_sigma_outliers(self._orig, self._sigma_mask)

    def _fit_rows(self):
        """Replace every row of the gain table by its polynomial fit."""
        for r in self._rows:
            self._fit_row(r, self._cols)

    def _fit_cols(self):
        """Replace every column of the gain table by its polynomial fit."""
        # Work on the transposed table, so columns can be treated like rows
        self.gain = self.gain.T
        for c in self._cols:
            self._fit_row(c, self._rows)
        self.gain = self.gain.T

    @staticmethod
    def _rank_warning():
        """Return the RankWarning class of the installed NumPy version."""
        # NumPy 2.0 removed `np.RankWarning` in favour of `np.exceptions.RankWarning`,
        # older releases only provide the top level name.
        try:
            return np.exceptions.RankWarning
        except AttributeError:
            return np.RankWarning

    def _fit_row(self, r: int, xes: range):
        """
        Replace row `r` of the current gain table by a polynomial fit of that row.

        :param r: index of the row to fit
        :param xes: the x positions of the row values
        """
        with warnings.catch_warnings():
            # Poorly conditioned fits are accepted silently
            warnings.simplefilter('ignore', self._rank_warning())
            poly = np.poly1d(np.polyfit(xes, self.gain[r], self._deg))
            self.gain[r] = poly(xes)

    def _blur(self):
        # skimage gaussian filter with its default parameters
        self.gain = gaussian(self.gain)
class GaussianBlur:
    """
    Creates a smooth gain table by applying a gaussian filter
    (`scipy.ndimage.gaussian_filter`) to the given image.
    """

    def __init__(self, img: np.ndarray, kernel: int = 250, truncate: float = 1.5):
        # The unmodified input image
        self._orig = img
        # Handed to the filter as `sigma`, i.e. the standard deviation of the gaussian
        self._kernel = kernel
        # Handed to the filter as `truncate`: the kernel is cut off at this many standard deviations
        self._truncate = truncate
        #: The blurred image. Uninitialized memory until `run` has been called.
        self.gain = np.empty(img.shape)

    def run(self) -> GaussianBlur:
        """
        Blur the input image and store the result in `gain`.

        :return: this instance
        """
        self.gain = gaussian_filter(self._orig, sigma=self._kernel, truncate=self._truncate)
        return self