Coverage for src/spectroflat/smile/interpolated_correction.py: 62%
45 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4provides SmileInterpolator
6@author: hoelken
7"""
8import numpy as np
9from scipy.interpolate import CubicSpline
10from qollib.processing.execution import simultaneous, CPU_LIM
12from .smile_correction import SmileCorrector
13from ..base import Logging
log = Logging.get_logger()


class SmileInterpolator(SmileCorrector):
    """
    ## SmileInterpolator

    De-skews the given image according to the provided `OffsetMap`: every row is
    fitted with a 1D cubic spline (`scipy.interpolate.CubicSpline`) through the row
    values placed at the shifted positions `x + offset`, and the spline is sampled
    back on the regular pixel grid `x`.
    """

    #: Upper bound on the number of parallel workers used to de-smile the rows;
    #: the effective count is `min(CPU_LIM, MAX_WORKERS)`.
    MAX_WORKERS = 7

    def _correct_smile(self):
        """
        De-smile every row of `self._img` in parallel and store the stacked rows in `self.result`.

        The offsets of each row are taken from `self._smap.get_offsets(row, self._mod_state)`.
        """
        n_rows, n_cols = self._img.shape
        rows = np.arange(n_rows)
        xes = np.arange(n_cols)
        args = [(r, xes, self._smap.get_offsets(r, self._mod_state), self._img[r]) for r in rows]
        workers = min(CPU_LIM, self.MAX_WORKERS)
        # Results are keyed by row id so the image can be re-assembled in row order
        # independent of the order in which `simultaneous` hands them back.
        res = dict(simultaneous(SmileInterpolator.desmile_row, args, workers=workers))
        self._construct_result(res, rows)

    def _construct_result(self, res: dict, rows: np.array):
        """
        Stack the de-smiled rows in the order given by `rows` and mask both image borders.

        ### Params
        res: dict mapping row id -> corrected row values (as returned by `desmile_row`)
        rows: the row ids, in output order
        """
        self.result = np.array([res[row] for row in rows])
        # Border width: largest absolute offset in the whole map, rounded down, plus one.
        # Columns this close to the left/right edge may lie outside the shifted sample
        # positions (i.e. be spline extrapolations), so they are replaced by the image mean.
        border = int(np.max(np.abs(self._smap.map))) + 1
        fill = self._img.mean()
        self.result[:, :border] = fill
        self.result[:, -border:] = fill

    @staticmethod
    def desmile_row(args: tuple) -> tuple:
        """
        De-smile a single row by cubic spline interpolation.

        ### Params
        args should be a tuple of:
            0 = row_id: the row index (passed through unchanged),
            1 = xes: np.array, the pixel positions of the row,
            2 = shifts: np.array, the offsets of this row (not modified),
            3 = values: np.array, the row values to correct.

        ### Returns
        A tuple of (row_id, corrected_row_values)

        ### Raises
        ValueError: if the spline cannot be built or evaluated (e.g. the shifted positions
        are not strictly increasing). The error is logged with the row id before re-raising.
        """
        row_id, xes, shifts, values = args[:4]
        try:
            positions = SmileInterpolator._monotonic_shifts(xes, shifts)
            spline = CubicSpline(positions, values)
            return row_id, spline(xes)
        except ValueError as e:
            log.error('Row %s: %s', row_id, e)
            raise

    @staticmethod
    def _monotonic_shifts(xes: np.array, shifts: np.array) -> np.array:
        """
        Return the shifted sample positions `xes + shifts`, with the zero-padded ends of
        `shifts` adjusted so that the positions do not step backwards at the row borders.

        - Leading zeros are replaced by the first non-zero shift. Otherwise a negative
          first shift would make the positions decrease at that transition.
        - Trailing zeros are replaced by the running maximum starting at the last non-zero
          shift. A positive last shift is therefore continued, and a negative one is
          followed by zeros.
        - If all shifts are zero, `xes + shifts` is returned unchanged.

        The given `shifts` array is NOT modified; a copy is adjusted instead. The array
        may be a view into the caller's offset map.
        """
        shifts = np.array(shifts, copy=True)
        nonzero = np.flatnonzero(shifts)
        if nonzero.size == 0:
            return xes + shifts
        first, last = nonzero[0], nonzero[-1]
        shifts[:first] = shifts[first]
        shifts[last:] = np.maximum.accumulate(shifts[last:])
        return xes + shifts

    @staticmethod
    def desmile_state(data: tuple) -> tuple:
        """
        Applies the smile correction to the image of a single mod state.

        Takes one tuple argument so that it can be mapped over several mod states
        in parallel (presumably via `simultaneous`, like `desmile_row`).

        ### Params
        data should be a tuple of:
            0 = smap: OffsetMap,
            1 = img: np.array, the image to correct,
            2 = state: Union[int, None]

        ### Returns
        A tuple of (state, corrected_image)
        """
        smap, img, state = data[0], data[1], data[2]
        corrector = SmileInterpolator(smap, img, mod_state=state).run()
        return state, corrector.result