diff --git a/navipy/trajectories/__init__.py b/navipy/trajectories/__init__.py index 82b97feb8b1023bd85397fcdb395d7444c85f74e..117ed83b4287c61c9b239ee8bd8a3032f94e4a58 100644 --- a/navipy/trajectories/__init__.py +++ b/navipy/trajectories/__init__.py @@ -15,6 +15,8 @@ from mpl_toolkits.mplot3d import Axes3D # noqa F401 from multiprocessing import Pool from functools import partial import time +from scipy import signal +from scipy.interpolate import CubicSpline def _markers2position(x, kwargs): @@ -250,7 +252,9 @@ class Trajectory(pd.DataFrame): def q_3(self, q_3): self.__set_q_i(3, q_3) - # overload of save/load function + # ------------------------------------------- + # ---------------- IO ----------------------- + # ------------------------------------------- def read_csv(self, filename, sep=',', header=[0, 1], index_col=0): """ Load from a hdf file """ @@ -272,8 +276,10 @@ class Trajectory(pd.DataFrame): df = pd.DataFrame(self) df.to_csv(filename) - # initialisation from variables - def from_array(self, nparray, rotconv): + # ------------------------------------------- + # ---------------- INITS FROM VAR------------ + # ------------------------------------------- + def from_array(self, nparray, rotconv, indeces=None): """ Assign trajectory from a numpy array N x 6 (rotconv = Euler angles) N x 7 (rotconv = quaternion) @@ -283,7 +289,17 @@ class Trajectory(pd.DataFrame): msg = 'nparray should be a np.ndarray and not {}' msg = msg.format(type(nparray)) raise TypeError(msg) - indeces = np.arange(0, nparray.shape[0]) + if indeces is None: + indeces = np.arange(0, nparray.shape[0]) + if not isinstance(indeces, np.ndarray): + msg = 'indeces should be a np.ndarray and not {}' + msg = msg.format(type(indeces)) + raise TypeError(msg) + if indeces.shape[0] != nparray.shape[0]: + msg = 'indeces and nparray should have same number of rows' + msg += '{}!={}' + msg = msg.format(indeces.shape[0], nparray.shape[0]) + raise TypeError(msg) if rotconv == 'quaternion': if nparray.shape[1] != 7: msg = 'nparray should have size Nx7 and not {}' @@ -463,6 +479,10 @@ class Trajectory(pd.DataFrame): self.loc[index_i, self.rotation_mode] = orientation return self + # ----------------------------------------------- + # ---------------- TRANSFORM -------------------- + # ----------------------------------------------- + def world2body(self, markers, indeces=None): """ Transform markers in world coordinate to body coordinate """ @@ -543,6 +563,108 @@ class Trajectory(pd.DataFrame): 'dalpha_2']] = rot.squeeze() return velocity + # -------------------------------------------- + # ---------------- FILTER -------------------- + # -------------------------------------------- + def filtfilt(self, order, cutoff, padlen=None): + """ + Filter the trajectory with order and cutoff by + using a lowpass filter twice (forward and backward) + to correct for phase shift + + :param order: the order of the lowpass filter. Either a number \ +or a pandas series. The series should be multiindexed as the columns of \ +the trajectory. + :param cutoff: cut off frequency in Hz if sampling rate is known\ +otherwise relative to the Nyquist frequency. Either a number or a pandas \ +series. + """ + if isinstance(order, [int, float]): + order = pd.Series(data=order, index=self.columns) + if isinstance(cutoff, [int, float]): + cutoff = pd.Series(data=cutoff, index=self.columns) + if not np.isnan(self.__sampling_rate): + nyquist = self.__sampling_rate/2 + cutoff /= nyquist + + subtraj = self.consecutive_blocks() + for trajno_nan in enumerate(subtraj): + indeces = trajno_nan.index + for col in self.columns: + b, a = signal.butter(order.loc[col], cutoff.loc[col]) + if padlen is None: + padlen = 3*max(len(a), len(b)) + if trajno_nan.shape[0] < padlen: + self.loc[indeces, col] *= np.nan + else: + self.loc[indeces, col] = signal.filtfilt( + b, a, + trajno_nan.loc[:, col], + padlen=padlen) + + def fillna(self, method='Cubic'): + """ fillna with a given method + """ + customs_method = ['Cubic'] + if not (method in customs_method): + # fall back to pandas fillna function + return self.fillna(method) + # Start implementing customs_method + if method == 'Cubic': + for col in self.loc[:, 'location'].columns: + values = self.loc[:, ('location', col)] + validtime = values.dropna().index + validvalues = values.dropna().values + cs = CubicSpline(validtime, validvalues) + time = self.index + self.loc[:, ('location', col)] = cs(time) + # for the angles we first do a ffill and then + # unwrap and interpolate on the unwrap angles + rotconv = self.rotation_mode + for col in self.loc[:, rotconv].columns: + values = self.loc[:, (rotconv, col)] + validtime = values.dropna().index + unwrapvalues = np.unwrap(values.fillna(method='ffill')) + validvalues = unwrapvalues[validtime] + cs = CubicSpline(validtime, validvalues) + time = self.index + self.loc[:, (rotconv, col)] = cs(time) + return self + + else: + msg = 'Method {} is not supported.' + msg += 'please use method supported by pd.fillna' + msg += ' or one of the following methods {}' + msg = msg.format(method, customs_method) + raise NameError(msg) + + # -------------------------------------------- + # ---------------- EXTRACT ------------------- + # -------------------------------------------- + + def consecutive_blocks(self): + """ Return a list of subtrajectory withtout nans + """ + # get a numpy array from the trajectory, + # because we are using numpy arrays later + np_traj = self.values + np_traj = np.hstack([self.index[:, np.newaxis], np_traj]) + # Look for row containing at least one nan + nonans = np.any(np.isnan(np_traj), axis=1) + # spliting the trajectory according to nan location + events = np.split(np_traj, np.where(nonans)[0]) + + # removing NaN entries + events = [ev[~np.any(np.isnan(ev), axis=1)] + for ev in events if isinstance(ev, np.ndarray)] + # removing empty DataFrames + subtraj = [Trajectory().from_dataframe(self.loc[ev[:, 0]]) + for ev in events if ev.size > 0] + return subtraj + # ------------------------------------------- + # ---------------- PLOTS -------------------- + # ------------------------------------------- + def lollipops(self, ax=None, colors=None, step_lollipop=1, offset_lollipop=0, lollipop_marker='o',