Coverage for src/spectroflat/fitting/line_detector.py: 100%
51 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
1from typing import Union
3import numpy as np
4from qollib.processing import simultaneous
5from qollib.ui import progress
7from .line import Line
8from .line_fit import LineFit
9from ..base import Logging
11logger = Logging.get_logger()
class LineDetector:
    """
    Detect absorption (or emission) lines in an image.

    The approximate line centers are supplied by the caller. One `Line` is created
    per center, and a set of evenly spaced anchor columns is selected across the
    image width. For each line, the area around the estimated center is then fitted
    with `LineFit` at every anchor column to locate the actual peak.

    The detected lines are available via the `lines` attribute after `run()`.
    """

    def __init__(self, image, line_centers: Union[list, np.ndarray], anchors: int = 170, line_distance: int = 80,
                 error_threshold: float = 2.1):
        """
        :param image: The image data as 2-dim matrix
        :param line_centers: Approximate centers of the lines to detect
        :param anchors: Number of anchor points to take for each line
        :param line_distance: Minimum distance of two lines
        :param error_threshold: Error threshold handed to `LineFit` for every fit
        """
        #: number of anchor points to take for each line
        self.anchors = anchors
        #: approximate centers of the lines to detect (one `Line` is created per entry)
        self.line_centers = np.array(line_centers)
        #: Integer > 1 to define the minimum distance of two lines.
        self.line_distance = line_distance
        #: Float passed to `LineFit` as `error_threshold` (max accepted fit error)
        self.error_threshold = error_threshold
        #: The image data as 2-dim matrix
        self.image = np.array(image)
        # list of cols to check
        self.check_cols = []
        #: resulting list of lines detected
        self.lines = []

    def run(self):
        """
        Detect lines at anchor points
        """
        self._normalize()
        self._create_lines()
        self._determine_cols_to_check()
        self._detect_lines_per_col()

    def _normalize(self):
        """
        Scale the image by its standard deviation and shift its minimum to zero.
        """
        image = self.image.astype(float)
        std = np.std(image)
        # A constant image has std == 0; dividing would turn every pixel into NaN/inf.
        if std > 0:
            image = image / std
        self.image = image - np.min(image)

    def _create_lines(self) -> None:
        """
        Create one `Line` per given line center.
        """
        self.lines = [self._create_line(peak) for peak in self.line_centers]

    def _create_line(self, peak: int) -> 'Line':
        """
        Create the `Line` for a single center, spanning the full image height.
        """
        return Line(peak, self.image.shape[0], rot_anker=0, line_distance=self.line_distance)

    def _determine_cols_to_check(self) -> None:
        """
        Select the anchor columns: evenly spaced, clamped to the last column, de-duplicated.
        """
        width = self.image.shape[1]
        dist = int(np.ceil(width / self.anchors))
        logger.debug('Creating anchors every %s columns', dist)
        # Anchors start at `dist` (not 0); clamping may create duplicates, hence np.unique.
        cols = [min(width - 1, dist * i) for i in range(1, self.anchors)]
        self.check_cols = np.unique(np.array(cols, dtype=int))

    def _detect_lines_per_col(self):
        """
        Run the per-line detection for all lines via `simultaneous`.
        """
        data = [self._line_args(line) for line in self.lines]
        self.lines = simultaneous(_detect_line, data)
        progress.dot(flush=True)

    def _line_args(self, line):
        """
        Bundle the arguments `_detect_line` needs for one line into a dict.
        """
        return {'line': line, 'cols': self.check_cols, 'data': self.image, 'error': self.error_threshold}
def _detect_line(args):
    """
    Fit one line at every anchor column and record the detected peak positions.

    Method on module level to allow parallelization.

    :param args: Dict with the keys
        'line' (the Line to populate),
        'cols' (columns to check),
        'data' (the image) and
        'error' (error threshold handed to LineFit)
    :return: the line from `args`, with every successfully fitted position added
    """
    line = args['line']
    cols = args['cols']
    # Transpose once so that indexing by `col` yields that image column.
    columns = np.transpose(args['data'])
    for col in cols:
        area = line.area(col)
        fitter = LineFit(area, columns[col][area], error_threshold=args['error'])
        try:
            fitter.run()
            line.add((fitter.max_location, col))
        except RuntimeError:
            # Best effort: a failed fit just leaves this column out of the line map.
            continue
    # Success means more than 65% of the anchor columns produced a fit.
    # The original divided cols by map entries, which is always >= 1 here.
    n_cols = len(cols)
    success = False if not line.map or n_cols == 0 else len(line.map) / n_cols > 0.65
    progress.dot(success=success)
    return line