Coverage for src/spectroflat/fitting/line_fit.py: 95%
73 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-03-28 07:59 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module to help with fitting Gauss Curves to data
6@author: hoelken
7"""
8import warnings
9import numpy as np
10from scipy.optimize import curve_fit, fminbound
11from scipy.signal import find_peaks
13from ..base import Logging
# Logger instance obtained from the project-local `Logging` helper (imported from `..base`).
# NOTE(review): nothing in the code visible here references `logger` - confirm whether it is still needed.
logger = Logging.get_logger()
class LineFit:
    """
    ## LineFit
    Helper class to take care of fitting noisy data of line profiles.

    Provided a set of `x` and `y` values of same dimension the process will first look for
    peaks in the `y`. Depending on the number of peaks the algorithm will try a single or
    overlapping fit and will compute starting amplitude, mean and sigma from the peak(s)
    found.

    It will first try a lorentzian fit, if this does not work it will try a gaussian fit as fallback.
    A fit is accepted when the mean of its standard deviation errors is below `error_threshold`.

    The resulting optimized values, covariance and errors can be retrieved directly after the fit was performed.
    Also, the x-location of the maximum (peak) is available.
    """

    def __init__(self, xes, yes, error_threshold=1.1):
        #: x axis
        self.xes = np.array(xes, dtype='float64')
        #: y values to x axis entries
        self.yes = np.array(yes, dtype='float64')
        #: Upper limit for the mean of the fit errors (`perr`). A fit (lorentzian first,
        #: then gaussian) is only accepted if it stays below this value.
        self.error_threshold = error_threshold
        # Initial values
        self.p0_args = []
        # Results
        #: POPT: Optimized values for (amplitude, center, sigma) per peak.
        #: if more than one peak is detected this will be multiple of 3 values with (a1, c1, s1, a2, c2, s2, ...)
        self.popt = None
        #: The estimated covariance of popt.
        self.pcov = None
        #: The standard deviation errors on (amplitude, center, sigma)
        self.perr = None
        #: the absolute max location (x)
        self.max_location = None
        #: Fit used
        self.fitting_function = None

    def run(self):
        """
        Trigger the fitting process.

        ### Returns
        The instance itself, so results can be read from the returned object.

        ### Raises
        `RuntimeError` if the fit was not successful
        """
        self._check_input()
        # Forget the outcome of an earlier run so a stale `perr` cannot pass for a fresh success.
        self.popt = None
        self.pcov = None
        self.perr = None
        self.max_location = None
        self.fitting_function = None
        self._initial_values()
        self._fit_line()
        self._find_max()
        return self

    def _check_input(self):
        """Raise `RuntimeError` if either the x or the y data set is empty."""
        if len(self.xes) == 0 or len(self.yes) == 0:
            raise RuntimeError('At least one of the given data sets is empty')

    def _fit_line(self):
        """
        Try the lorentzian fit first and the gaussian fit second.
        The first one whose mean error is below `error_threshold` wins.

        ### Raises
        `RuntimeError` if neither fit is accepted
        """
        for attempt in (self._fit_lorentz, self._fit_gauss):
            attempt()
            if self.perr is not None and np.mean(self.perr) < self.error_threshold:
                return

        raise RuntimeError('Could not fit given data. Neither Gauss nor Lorentz function worked.')

    def _initial_values(self):
        """Compute the start parameters (amplitude, center, sigma) for every detected peak."""
        # Start from scratch: appending to the guesses of an earlier call would make a
        # repeated `run()` fit peaks that do not exist.
        self.p0_args = []
        # `find_peaks` rejects a distance below 1, which `len // 3` yields for fewer than three samples.
        min_distance = max(1, len(self.yes) // 3)
        peaks, _ = find_peaks(self.yes, distance=min_distance, prominence=0.05)
        if len(peaks) == 0:
            # No prominent peak found: fall back to the position of the absolute maximum.
            peaks = [np.argmax(self.yes)]

        ymin = np.min(self.yes)
        ysum = np.sum(self.yes)
        for peak in peaks:
            center = self.xes[peak]
            # Spread of the y-weighted x values around the peak position.
            sigma = np.sqrt(np.sum(self.yes * (self.xes - center) ** 2) / ysum)
            self.p0_args.extend((self.yes[peak] - ymin, center, sigma))

    def _fit_gauss(self):
        """Attempt a (possibly overlapping) gaussian fit."""
        self.fitting_function = 'gaussian'
        self._fit(overlapping_gaussian)

    def _fit_lorentz(self):
        """Attempt a (possibly overlapping) lorentzian fit."""
        self.fitting_function = 'lorentzian'
        self._fit(overlapping_lorentzian)

    def _find_max(self):
        """Locate the x position of the maximum of the fitted function within the x range of the data."""
        if self.fitting_function == 'lorentzian':
            profile = overlapping_lorentzian
        else:
            profile = overlapping_gaussian
        # `fminbound` minimizes, so the negated profile is searched to find the maximum.
        self.max_location = fminbound(lambda x: -profile(x, *self.popt), np.min(self.xes), np.max(self.xes))

    def _fit(self, func):
        """
        Run `curve_fit` with `func` and store `popt`, `pcov` and `perr` on success.

        A failing fit is deliberately swallowed here: the results are left untouched and
        `_fit_line` decides by looking at `perr` whether to fall back or give up.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            try:
                self.popt, self.pcov = curve_fit(func, self.xes, self.yes, p0=self.p0_args)
                self.perr = np.sqrt(np.diag(self.pcov))
            except (TypeError, ValueError, RuntimeWarning, RuntimeError):
                # TypeError: more parameters than data points.
                # ValueError: data contains NaN or inf values.
                # RuntimeError: the least-squares minimization did not converge.
                pass
def gaussian(x, amplitude, mean, sigma) -> float:
    """
    Evaluate a [Gaussian normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) shaped peak.

    The signature is the one `scipy.optimize.curve_fit` expects from its callable,
    see [curve_fit documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html):
    the independent variable comes first, the parameters to fit follow as separate arguments.

    ### Params
    - `x` The free variable
    - `amplitude` The height of the peak, reached at `x == mean`
    - `mean` The center of the peak
    - `sigma` The standard deviation (The width of the peak)

    ### Returns
    The y value
    """
    squared_offset = np.power(x - mean, 2.)
    doubled_variance = 2 * np.power(sigma, 2.)
    return amplitude * np.exp(-squared_offset / doubled_variance)
def overlapping_gaussian(x, *args):
    """
    Sum of one or more (potentially overlapping) gaussian shaped peaks, usable as fitting function.

    There is always a single `x`; the remaining arguments are consumed in packs of three
    (amplitude, mean, sigma), one pack per peak. Arguments that do not fill a complete
    pack are ignored.

    See `gaussian` for further details
    """
    total = 0
    complete = 3 * (len(args) // 3)
    for start in range(0, complete, 3):
        amplitude, mean, sigma = args[start:start + 3]
        total = total + gaussian(x, amplitude, mean, sigma)
    return total
def lorentzian(x, amplitude, center, width) -> float:
    """
    Evaluate a [Cauchy-Lorentzian distribution](https://en.wikipedia.org/wiki/Cauchy_distribution) shaped peak.

    The signature is the one `scipy.optimize.curve_fit` expects from its callable,
    see [curve_fit documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html):
    the independent variable comes first, the parameters to fit follow as separate arguments.

    ### Params
    - `x` The free variable
    - `amplitude` The height of the peak, reached at `x == center`
    - `center` The center of the peak
    - `width` The width of the peak (the value drops to half the amplitude at `center +/- width`)

    ### Returns
    The y value
    """
    width_squared = width ** 2
    offset_squared = (x - center) ** 2
    return amplitude * width_squared / (offset_squared + width_squared)
def overlapping_lorentzian(x, *args) -> float:
    """
    Sum of one or more (potentially overlapping) lorentzian shaped peaks, usable as fitting function.

    There is always a single `x`; the remaining arguments are consumed in packs of three
    (amplitude, center, width), one pack per peak. Arguments that do not fill a complete
    pack are ignored.

    See `lorentzian` for further details
    """
    peak_count = len(args) // 3
    packs = [args[3 * k:3 * k + 3] for k in range(peak_count)]
    return sum(lorentzian(x, *pack) for pack in packs)