Coverage for src/spectroflat/smile/interpolated_correction.py: 62%

45 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-03-28 07:59 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4provides SmileInterpolator 

5 

6@author: hoelken 

7""" 

8import numpy as np 

9from scipy.interpolate import CubicSpline 

10from qollib.processing.execution import simultaneous, CPU_LIM 

11 

12from .smile_correction import SmileCorrector 

13from ..base import Logging 

14 

# Module-level logger obtained from the project's `Logging` helper; used by
# `SmileInterpolator.desmile_row` to report rows whose interpolation fails.
log = Logging.get_logger()

16 

17 

class SmileInterpolator(SmileCorrector):
    """
    ## SmileInterpolator

    De-skews the given image according to the provided `OffsetMap` using row-wise (1D)
    cubic spline interpolation: every row is resampled independently along the column axis.
    """

    # Upper bound on the number of parallel workers used for the row-wise correction;
    # the effective count is `min(CPU_LIM, MAX_WORKERS)`. Subclasses may override it.
    MAX_WORKERS: int = 7

    def _correct_smile(self):
        """
        Correct every row of `self._img` in parallel and store the assembled image in `self.result`.

        Each row is paired with its offsets from `self._smap.get_offsets(row, self._mod_state)`
        and handed to `desmile_row`.
        """
        n_rows, n_cols = self._img.shape
        rows = np.arange(n_rows)
        xes = np.arange(n_cols)
        args = [(r, xes, self._smap.get_offsets(r, self._mod_state), self._img[r]) for r in rows]
        workers = min(CPU_LIM, self.MAX_WORKERS)
        res = dict(simultaneous(SmileInterpolator.desmile_row, args, workers=workers))
        self._construct_result(res, rows)

    def _construct_result(self, res: dict, rows: np.ndarray):
        """
        Assemble the corrected rows into `self.result` and blank its left/right borders.

        ### Params
        - res: mapping row index -> corrected row values
        - rows: row indices, in the order in which they are stacked into the result

        The border width is `int(max(|offset map|)) + 1` columns on each side, and the border
        is filled with the mean of the input image.
        """
        self.result = np.array([res[row] for row in rows])
        # Presumably these columns lie outside the range covered by the shifted sample
        # positions, so the spline only extrapolates there; replace them with a neutral value.
        border = int(np.max(np.abs(self._smap.map))) + 1
        fill_value = self._img.mean()
        self.result[:, :border] = fill_value
        self.result[:, -border:] = fill_value

    @staticmethod
    def desmile_row(args: tuple) -> tuple:
        """
        Resample a single image row onto the regular column grid.

        ### Params
        args: tuple of (row_id, xes, shifts, row_values)
            - row_id: index of the row, passed through unchanged
            - xes: regular column coordinates (`_correct_smile` passes `np.arange(cols)`)
            - shifts: per-column offsets of this row
            - row_values: the row's pixel values

        ### Returns
        A tuple of (row_id, corrected_row_values).

        ### Raises
        ValueError if the spline cannot be built, e.g. when `xes + shifts` is not strictly
        increasing. The error is logged together with the row id and then re-raised.
        """
        row_id, xes, shifts, values = args
        try:
            positions = SmileInterpolator._monotonic_shifts(xes, shifts)
            # The samples were taken at `positions`; evaluate the spline on the regular grid.
            spline = CubicSpline(positions, values)
            return row_id, spline(xes)
        except ValueError as e:
            log.error('Row %s: %s', row_id, e)
            raise

    @staticmethod
    def _monotonic_shifts(xes: np.ndarray, shifts: np.ndarray) -> np.ndarray:
        """
        Return the sample positions `xes + shifts`, with the zero-padded edges of `shifts` fixed up.

        - Entries before the first non-zero shift take the value of the first non-zero shift.
        - Entries from the last non-zero shift onward are replaced by their running maximum,
          so a positive trailing shift is carried to the end of the row.
        - If every shift is zero, `xes + shifts` is returned unchanged.

        The caller's `shifts` array is NOT modified.
        """
        # Work on a copy. `shifts` comes straight from `OffsetMap.get_offsets` and may be a
        # view into the shared offset map (NOTE(review): confirm). Editing it in place would
        # leak these edge adjustments back into the map.
        shifts = np.array(shifts, copy=True)
        non_zero = np.flatnonzero(shifts)
        if non_zero.size == 0:
            return xes + shifts
        first, last = non_zero[0], non_zero[-1]
        shifts[:first] = shifts[first]
        shifts[last:] = np.maximum.accumulate(shifts[last:])
        return xes + shifts

    @staticmethod
    def desmile_state(data: tuple) -> tuple:
        """
        Applies the smile correction to a single mod state.

        Written as a single-argument worker so that several mod states can be corrected
        simultaneously, e.g. by mapping it over one tuple per state.

        ### Params
        data should be a tuple of:
         0 = smap: OffsetMap,
         1 = img: np.ndarray, the image to correct,
         2 = state: Union[int, None]

        ### Returns
        A tuple of (state, corrected_image)
        """
        sc = SmileInterpolator(data[0], data[1], mod_state=data[2]).run()
        return data[2], sc.result