Coverage for src/spectroflat/fitting/line_detector.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-03-28 07:59 +0000

1from typing import Union 

2 

3import numpy as np 

4from qollib.processing import simultaneous 

5from qollib.ui import progress 

6 

7from .line import Line 

8from .line_fit import LineFit 

9from ..base import Logging 

10 

# Module-level logger obtained from the project's ``Logging`` helper; used for debug messages in this module.
logger = Logging.get_logger()

12 

13 

class LineDetector:
    """
    Detect absorption (or emission) lines in an image at a set of anchor columns.

    The approximate line centers are supplied by the caller (``line_centers``); one
    ``Line`` is created per center. After normalizing the image, lines are sampled at
    evenly spaced anchor columns. For every line and anchor column, the pixels selected
    by ``Line.area(col)`` are fitted with ``LineFit`` to locate the actual peak.
    The populated lines are available via the ``lines`` attribute after ``run()``.
    """

    def __init__(self, image, line_centers: Union[list, np.ndarray], anchors: int = 170,
                 line_distance: int = 80, error_threshold: float = 2.1):
        #: number of anchor points per line; at most ``anchors - 1`` columns are checked
        self.anchors = anchors
        #: approximate center of each line; one ``Line`` is created per entry
        self.line_centers = np.array(line_centers)
        #: integer > 1 forwarded to ``Line`` (defines the minimum distance of two lines)
        self.line_distance = line_distance
        #: error threshold forwarded to ``LineFit``
        #: (documented as max error for a gaussian fit before trying a lorentzian - TODO confirm in LineFit)
        self.error_threshold = error_threshold
        #: the image data as 2-dim matrix
        self.image = np.array(image)
        #: sorted, unique column indices at which every line is fitted (filled by ``run()``)
        self.check_cols = []
        #: resulting list of lines detected
        self.lines = []

    def run(self):
        """
        Detect lines at anchor points.

        Normalizes the image, creates one ``Line`` per configured center, chooses the
        anchor columns and fits every line at those columns. Results end up in ``self.lines``.
        """
        self._normalize()
        self._create_lines()
        self._determine_cols_to_check()
        self._detect_lines_per_col()

    def _normalize(self):
        """Scale the image to unit standard deviation and shift its minimum to 0."""
        std = np.std(self.image)
        # A constant image has std == 0; dividing would turn every pixel into NaN (0/0).
        if std > 0:
            self.image = self.image / std
        self.image = self.image - np.min(self.image)

    def _create_lines(self) -> None:
        """Create one ``Line`` per configured line center."""
        self.lines = [self._create_line(peak) for peak in self.line_centers]

    def _create_line(self, peak: int) -> 'Line':
        """Create a ``Line`` centered at ``peak`` spanning the image's first axis."""
        return Line(peak, self.image.shape[0], rot_anker=0, line_distance=self.line_distance)

    def _determine_cols_to_check(self) -> None:
        """Pick evenly spaced column indices (clipped to the last column, deduplicated)."""
        dist = int(np.ceil(self.image.shape[1] / self.anchors))
        logger.debug('Creating anchors every %s rows', dist)
        self.check_cols = np.array([min(self.image.shape[1] - 1, dist * i) for i in range(1, self.anchors)], dtype=int)
        self.check_cols = np.unique(self.check_cols)

    def _detect_lines_per_col(self):
        """Fit every line at all anchor columns, one parallel task per line."""
        data = [self._line_args(line) for line in self.lines]
        self.lines = simultaneous(_detect_line, data)
        progress.dot(flush=True)

    def _line_args(self, line):
        """Build the argument dict consumed by the module-level ``_detect_line`` worker."""
        return {'line': line, 'cols': self.check_cols, 'data': self.image, 'error': self.error_threshold}

72 

73 

def _detect_line(args):
    """
    Fit a single line at every requested column and collect the peak positions.

    Method on module level to allow parallelization (``LineDetector`` dispatches it
    through ``simultaneous``, one call per line).

    :param args: dict with the keys
        ``'line'``:  the ``Line`` to populate; ``line.area(col)`` selects the pixels to fit,
        ``'cols'``:  column indices to check,
        ``'data'``:  the 2-dim image,
        ``'error'``: error threshold forwarded to ``LineFit``.
    :return: the same ``Line`` object, with one ``(max_location, col)`` point added for
        every column whose fit did not raise ``RuntimeError``.
    """
    line = args['line']
    cols = args['cols']
    # Transpose once so that columns[col] is image column `col` (same as data[:, col]).
    columns = np.transpose(args['data'])
    error_cols = []
    for col in cols:
        area = line.area(col)
        fitter = LineFit(area, columns[col][area], error_threshold=args['error'])
        try:
            fitter.run()
        except RuntimeError:
            error_cols.append(col)
        else:
            line.add((fitter.max_location, col))
    if error_cols:
        logger.debug('Line fit failed at %s of %s columns', len(error_cols), len(cols))
    # Success means more than 65 % of the checked columns produced a fitted point
    # (fitted points divided by checked columns, not the other way around).
    success = bool(line.map) and len(cols) > 0 and len(line.map) / len(cols) > 0.65
    progress.dot(success=success)
    return line