Source code for pysolate.wavelet

# -*- coding: utf-8 -*-
# SPDX-License-Identifier: MIT
# LLNL-CODE-841231
# authors:
#        Ana Aguiar (aguiarmoya1@llnl.gov)
#        Andrea Chiang (andrea@llnl.gov)
#
# Wavelet transform functions are based on the MATLAB
# Synchrosqueezing Toolbox by Eugene Brevdo.
# (http://www.math.princeton.edu/~ebrevdo/)
"""
Routines for handling the wavelet transform
"""


import math
import numpy as np
import fnmatch
from scipy.integrate import quad
from obspy.core.trace import Stats
from obspy.core.util import AttribDict


def _psihfn(xi,wave_type):
    return np.conj(wfiltfn(xi, wave_type))*wfiltfn(xi, wave_type)/xi


[docs]def wfilth(wave_type, N, a): """ Fast fourier transform of the wavelet function Outputs the FFT of a given ``wave_type`` of length N at scale a. :param wave_type: wavelet function. :type wave_type: str :param N: number of samples to calculate. :type N: int :param a: wavelet scale. :type a: float :return: wavelet sampling in frequency domain :rtype: :class:`numpy.ndarray` """ k = np.arange(0, N, 1, dtype=int) xi = np.zeros(N, dtype=float) xi[:int(N/2)+1] = 2 * math.pi/N * np.arange(0,N/2+1, 1) xi[int(N/2)+1:] = 2 * math.pi/N * np.arange(-N/2+1, 0, 1) tmpxi = a*xi psih = wfiltfn(tmpxi, wave_type) # Normalizing psih = psih * math.sqrt(a) / math.sqrt(2*math.pi) # Center around zero in the time domain psih = psih * np.power((-1),k) return psih
[docs]def wfiltfn(xi, wave_type): """ Wavelet transform function of the wavelet filter in fourier domain :param xi: sampled time series. :type xi: :class:`numpy.ndarray` :param wave_type: wavelet function. :type wave_type: str :return: mother wavelet function. :rtype: :class:`numpy.ndarray` """ if wave_type == "mhat": s = 1 psihfn = -np.sqrt(8) * s**(5/2) * (np.pi**(1/4)/np.sqrt(3)) * (xi**2) * np.exp((-s**2)*(xi**2/2)) elif wave_type == "morlet": mu = 2*np.pi cs = (1 + np.exp(-mu**2) - 2*np.exp(-3/4*(mu**2)))**(-1/2) ks = np.exp(-1/2*(mu**2)) psihfn = cs*(np.pi**(-1/4)) * np.exp(-1/2*(mu-xi)**2) - ks*np.exp(-1/2*(xi**2)) elif wave_type == "shannon": pass elif wave_type == "hhat": psihfn = 2/np.sqrt(5) * np.pi**(-1/4) * xi * (1+xi) * np.exp(-1/2*xi**2) return np.array(psihfn)
[docs]def cwt(trace, wave_type, nvoices): """ Continuous wavelet transform using the wavelet function :param trace: input time series data. :type time_series: list or :class:`numpy.ndarray` :param wave_type: wavelet function. :type wave_type: str :param nvoices: sampling of CWT in scale. :type nvoices: int :returns: the wavelet transform of shape (scales, time_series), and the length vector containing the associated scales. :rtype: (:class:`numpy.ndarray`, :class:`numpy.ndarray`) """ time_series = trace.data dt = trace.stats.delta n = trace.stats.npts # Padding the signal N = int(2**(1+round(math.log2(len(time_series))+np.finfo(float).eps))) n1 = math.floor((N-n)/2) n2 = n1 if math.fmod(2*n1+n, 2) == 1: n2 = n1 + 1 time_series = np.pad(time_series, (n1,n2), mode="constant") # Choosing more than this means the wavelet window becomes too short noctaves = np.log2(N)-1 # number of octaves if noctaves <= 0 and math.fmod(noctaves,1) != 0: raise ValueError("noctaves has to be higher than zero.") if nvoices <= 0 and math.fmod(nvoices,1) != 0: raise ValueError("nvoices has to be higher than zero.") if dt <= 0: raise ValueError("dt has to be higher than zero.") if np.any(np.isnan(time_series)): raise ValueError("time_series has null values.") num_scales = int(noctaves*nvoices) mi = np.arange(1, num_scales+1, 1, dtype=int) scales = np.power(2**(1/nvoices), mi) Wx = np.zeros((num_scales, N), dtype=complex) time_series_fft = np.fft.fft(time_series) # for each octave (this part should be parallelized?) for i in range(num_scales): psih = wfilth(wave_type, N, scales[i]) # wavelet sampling Wx[i, :] = np.fft.ifftshift(np.fft.ifft(psih*time_series_fft)) # Output scales for graphing purposes, scale by dt scales = scales * dt Wxout = Wx[:,n1+1:n1+n+1] return (Wxout, scales)
[docs]def inverse_cwt(Wx, wave_type, nvoices): """ Inverse continuous wavelet tranform Reconstructs the original signal from the wavelet transform. :param Wx: wavelet transform of shape (len(scales), len(time_series)) :type Wx: :class:`numpy.ndarray` :param wave_type: wavelet function. :type wave_type: str :param nvoices: sampling of CWT in scale. :type nvoices: int :return: time series data. :rtype: :class:`numpy.ndarray` """ [num_scales, n] = np.shape(Wx) #Padding the siggnal N = int(2**(1+round(math.log2(n)+np.finfo(float).eps))) n1 = math.floor((N-n)/2) # n2 = n1 # if math.fmod(2*n1+n,2)==1: # n2 = n1 + 1 Wx_padded = np.zeros((num_scales, N), dtype=complex) Wx_padded[:,n1+1:n1+n+1] = Wx noctaves = np.log2(N) - 1 if math.fmod(noctaves,1) != 0: raise ValueError("noctaves has to be nonzero.") if nvoices <= 0 and math.fmod(nvoices,1) != 0: raise ValueError("nvoices has to be higher than zero.") mi = np.arange(1, num_scales+1, 1, dtype=int) scales = np.power(2**(1/nvoices), mi) # For type 'shannon', have to include the following if wave_type == "shannon": Cpsi = log(2) else: [Cpsi, Cpsi_err] = quad(_psihfn, 0, np.inf, args=(wave_type) ) # Normalize Cpsi = Cpsi / (4*np.pi) x = np.zeros((1, N)) # for each octave for ai in range(num_scales): Wxa = Wx_padded[ai, :] psih = wfilth(wave_type, N, scales[ai]) # Convolution theorem Wxah = np.fft.fft(Wxa) xah = Wxah * psih xa = np.fft.ifftshift(np.fft.ifft(xah)) x = x + xa/scales[ai] # Take real part and normalize by log_e(a)/Cpsi x = np.log(2**(1/nvoices))/Cpsi * np.real(x) # Keep the unpadded part time_series = x[0][n1+1:n1+n+1] return time_series
[docs]def blockbandpass(Wx, scales, scale_min, scale_max, block_threshold): """ Apply a band reject filter to modify the wavelet coefficients over a scale bandpass :param Wx: wavelet transform of shape (len(scales), len(time_series)) :type Wx: :class:`numpy.ndarray` :param scale_min: minimum time scale for bandpass blocking. :type scale_min: float :param scale_max: maximum time scale for bandpass blocking. :type scale_max: float :param block_threshhold: percent amplitude adjustment to the wavelet coefficients within ``scale_min`` and ``scale_max``. :type block_threshold: float :return: modified wavelet transform of shape (len(scales), len(time_series)). :rtype: :class:`numpy.ndarray` """ na, n = np.shape(Wx) thresh = block_threshold*0.01 a = np.ones((1,na)) a = a*[ (scales <= scale_min) | (scales >= scale_max) ] for k in range(na): if a[0,k] == 0: a[0,k] = thresh return Wx*a.T
[docs]class Wavelet(object): """ One dimensional continues wavelet transform of a time series Main class for a single wavelet transform including the station headers, wavelet coefficients and scales. Noise model calculated from the function :func:`~pysolate.threshold.noise_model` is passed through additional ``kwargs``. :param coefs: Continuous wavelet transform of a time series, the first axis corresponds to the scales and the second axis corresponds to the length of the time series. :type coefs: :class:`numpy.ndarray` :param scales: Wavelet scales :type scales: :class:`numpy.ndarray` :param headers: header information of the data. :type headers: dict, :class:`~obspy.core.trace.Stats` :param M: mean of noise model. :type M: :class:`numpy.ndarray` :param S: standard deviation of noise model. :type S: :class:`numpy.ndarray` :param P: threshold of the noise signal. :type P: :class:`numpy.ndarray` .. rubric:: Attributes ``stats`` : :class:`~obspy.core.trace.Stats` header of the wavelet transform, including station info. ``scales`` : :class:`numpy.ndarray` wavelet scales. ``coefs`` : :class:`numpy.ndarray` wavelet coefficients. ``noise_model`` : :class:`~obspy.core.util.attribdict.AttribDict` noise model """ def __init__(self, coefs=None, scales=None, headers=None, **kwargs): if headers is None: headers = {} self.stats = Stats(headers) self.scales = scales self.coefs = coefs self.noise_model = AttribDict(**kwargs)
[docs] def get_id(self): """ Returns the station ID :return: station ID containing network, station, location and channel codes. :rtype: str """ out = "%(network)s.%(station)s.%(location)s.%(channel)s" return out%(self.stats)
def __str__(self): """ Return better readable string representation of Parameter object. """ out = " | %(starttime)s - %(endtime)s | " attrs = ', '.join( [key for key in ["scales", "coefs"] if key is not None], ) if len(self.noise_model) > 0: attrs = ', '.join([attrs, "noise_model"]) return self.get_id() + out%(self.stats) + attrs def _repr_pretty_(self, p, cycle): p.text(str(self))
[docs]class WaveletCollection(object): """ A collection of wavelet objects Main class that contains wavelet transforms for multiple time series. :param wavelets: wavelet transform(s). :type wavelets: :class:`~pysolate.wavelet.Wavelet`, list """ def __init__(self, wavelets=None): self.wavelets = [] if isinstance(wavelets, Wavelet): wavelets = [wavelets] if wavelets: self.wavelets.extend(wavelets) def __len__(self): return len(self.wavelets) count = __len__ def __setitem__(self, index, wavelet): """ __setitem__ method. """ self.wavelets.__setitem__(index, wavelet) def __getitem__(self, index): """ __getitem__ method. :return: Wavelet objects """ if isinstance(index, slice): return self.__class__(wavelets=self.wavelets.__getitem__(index)) else: return self.wavelets.__getitem__(index) def __delitem__(self, index): """ Passes on the __delitem__ method to the underlying list of wavelets. """ return self.wavelets.__delitem__(index) def __getslice__(self, i, j, k=1): """ __getslice__ method. :return: WaveletCollection object """ return self.__class__(wavelets=self.wavelets[max(0, i):max(0, j):k])
[docs] def select(self, network=None, station=None, location=None, channel=None, component=None): """ Query wavelets Select wavelet transforms that matches the given station criteria. :param tag: processed data to plot. :type tag: str :param network: network code. :type network: str :param station: station code. :type station: str :param location: location code. :type location: str :param channel: channel code. :type channel: str :param component: component code. :type component: str """ # Adapted from ObsPy Stream select method wavelets = [] if component is not None and channel is not None: component = component.upper() channel = channel.upper() if (channel[-1:] not in "?*" and component not in "?*" and component != channel[-1:]): msg = "Selection criteria for channel and component are " + \ "mutually exclusive!" raise ValueError(msg) for _i in self.wavelets: if network is not None: if not fnmatch.fnmatch( _i.stats.network.upper(), network.upper() ): continue if station is not None: if not fnmatch.fnmatch( _i.stats.station.upper(), station.upper() ): continue if location is not None: if not fnmatch.fnmatch( _i.stats.location.upper(), location.upper() ): continue if channel is not None: if not fnmatch.fnmatch( _i.stats.channel.upper(), channel.upper() ): continue if component is not None: if not fnmatch.fnmatch( _i.stats.component.upper(), component.upper() ): continue wavelets.append(_i) return self.__class__(wavelets=wavelets)
def __str__(self): pretty_string = "\n".join([_i.__str__() for _i in self]) return pretty_string def _repr_pretty_(self, p, cycle): p.text(str(self))