diff --git a/navipy/trajectories/__init__.py b/navipy/trajectories/__init__.py index 117ed83b4287c61c9b239ee8bd8a3032f94e4a58..ce707c05322a74f1fce71d39b9219ef17b564cdf 100644 --- a/navipy/trajectories/__init__.py +++ b/navipy/trajectories/__init__.py @@ -56,7 +56,7 @@ class Trajectory(pd.DataFrame): columns = self.__build_columns(rotconv) super().__init__(index=indeces, columns=columns) self.__rotconv = rotconv - self.__sampling_rate = np.nan + self.sampling_rate = 0 def __build_columns(self, rotconv): if rotconv == 'quaternion': @@ -135,10 +135,7 @@ class Trajectory(pd.DataFrame): @property def sampling_rate(self): - if np.isnan(self.__sampling_rate): - raise NameError('Sampling rate has not be set') - else: - return self.index.name + return self.__sampling_rate @sampling_rate.setter def sampling_rate(self, sampling_rate): @@ -579,16 +576,15 @@ the trajectory. otherwise relative to the Nyquist frequency. Either a number or a pandas \ series. """ - if isinstance(order, [int, float]): - order = pd.Series(data=order, index=self.columns) - if isinstance(cutoff, [int, float]): - cutoff = pd.Series(data=cutoff, index=self.columns) - if not np.isnan(self.__sampling_rate): + if self.sampling_rate > 0: nyquist = self.__sampling_rate/2 cutoff /= nyquist - + if isinstance(order, (int, float)): + order = pd.Series(data=order, index=self.columns) + if isinstance(cutoff, (int, float)): + cutoff = pd.Series(data=cutoff, index=self.columns) subtraj = self.consecutive_blocks() - for trajno_nan in enumerate(subtraj): + for trajno_nan in subtraj: indeces = trajno_nan.index for col in self.columns: b, a = signal.butter(order.loc[col], cutoff.loc[col]) @@ -597,10 +593,16 @@ series. if trajno_nan.shape[0] < padlen: self.loc[indeces, col] *= np.nan else: - self.loc[indeces, col] = signal.filtfilt( - b, a, - trajno_nan.loc[:, col], - padlen=padlen) + if col[0] == 'location': + self.loc[indeces, col] = signal.filtfilt( + b, a, + trajno_nan.loc[:, col], + padlen=padlen).astype(float) + else: + self.loc[indeces, col] = signal.filtfilt( + b, a, + np.unwrap(trajno_nan.loc[:, col]), + padlen=padlen).astype(float) def fillna(self, method='Cubic'): """ fillna with a given method @@ -669,7 +671,8 @@ series. colors=None, step_lollipop=1, offset_lollipop=0, lollipop_marker='o', linewidth=1, lollipop_tail_width=1, lollipop_tail_length=1, - lollipop_head_size=1, stickdir='backward' + lollipop_head_size=1, stickdir='backward', + plotcoords=['x', 'y', 'z'] ): """ lollipops plot @@ -699,12 +702,27 @@ series. :param lollipop_head_size: The size of the lollipop :param stickdir: The direction of the stick of the animal \ (backward or forward) + :param plotcoords: the dimension to plots, e.g. ['x','y','z'] for 3d plots \ +['x','y'] for a 2d plot """ # import time t_start = time.time() if ax is None: - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') + if len(plotcoords) == 3: + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + elif len(plotcoords) == 2: + ax = fig.add_subplot(111) + if (len(plotcoords) != 2) and (len(plotcoords) != 3): + msg = 'plotcoords need to contains 2 or 3 elements' + msg += ' for 2d and 3d plots respectively' + raise ValueError(msg) + if ax.name == '3d': + plotcoords = ['x', 'y', 'z'] + elif len(plotcoords) > 2: + plotcoords = plotcoords[:2] + + # Start computing for direction direction = self.facing_direction() if colors is None: timeseries = pd.Series(data=self.index, @@ -721,9 +739,9 @@ series. [[0], ['x', 'y', 'z']])) if stickdir == 'forward': - tailmarker.loc[0, 'x'] = -lollipop_tail_length - else: tailmarker.loc[0, 'x'] = lollipop_tail_length + else: + tailmarker.loc[0, 'x'] = -lollipop_tail_length tail = self.world2body(tailmarker, indeces=indeces) tail = tail.loc[:, 0] # Plot the agent trajectory @@ -735,6 +753,10 @@ series. z = self.loc[:, ('location', 'z')] print(time.time() - t_start) t_start = time.time() + line = dict() + line['x'] = self.x + line['y'] = self.y + line['z'] = self.z if isinstance(colors, pd.DataFrame): # Each segment will be plotted with a different color # we therefore need to loop through all points @@ -748,14 +770,27 @@ series. color = [colors.r[frame_i], colors.g[frame_i], colors.b[frame_i], colors.a[frame_i]] # Create the line to plot - line_x = [x[frame_i], x[frame_j]] - line_y = [y[frame_i], y[frame_j]] - line_z = [z[frame_i], z[frame_j]] + line['x'] = [x[frame_i], x[frame_j]] + line['y'] = [y[frame_i], y[frame_j]] + line['z'] = [z[frame_i], z[frame_j]] # Actual plot command - ax.plot(xs=line_x, ys=line_y, zs=line_z, - color=color, linewidth=linewidth) + if len(plotcoords) == 3: + ax.plot(xs=line['x'], ys=line['y'], zs=line['z'], + color=color, linewidth=linewidth) + else: + # len(plotcoords) == 2 because check earlier + ax.plot(line[plotcoords[0]], line[plotcoords[1]], + color=color, linewidth=linewidth) + else: - ax.plot(xs=x, ys=y, zs=z, color=colors, linewidth=linewidth) + # Actual plot command + if len(plotcoords) == 3: + ax.plot(xs=line['x'], ys=line['y'], zs=line['z'], + color=colors, linewidth=linewidth) + else: + # len(plotcoords) == 2 because check earlier + ax.plot(line[plotcoords[0]], line[plotcoords[1]], + color=colors, linewidth=linewidth) print(time.time() - t_start) t_start = time.time() # Plot the lollipop @@ -777,19 +812,29 @@ series. else: color = colors # Create the line to plot - line_x = [self.x[frame_i], - tail.x[frame_i]] - line_y = [self.y[frame_i], - tail.y[frame_i]] - line_z = [self.z[frame_i], - tail.z[frame_i]] + line['x'] = [self.x[frame_i], + tail.x[frame_i]] + line['y'] = [self.y[frame_i], + tail.y[frame_i]] + line['z'] = [self.z[frame_i], + tail.z[frame_i]] # Actual plot command - ax.plot(xs=line_x, ys=line_y, zs=line_z, - color=color, linewidth=lollipop_tail_width) - ax.plot(xs=[line_x[0]], - ys=[line_y[0]], - zs=[line_z[0]], - color=color, - marker=lollipop_marker, - markersize=lollipop_head_size) + if len(plotcoords) == 3: + ax.plot(xs=line['x'], ys=line['y'], zs=line['z'], + color=color, linewidth=lollipop_tail_width) + ax.plot(xs=[line['x'][0]], + ys=[line['y'][0]], + zs=[line['z'][0]], + color=color, + marker=lollipop_marker, + markersize=lollipop_head_size) + else: + # len(plotcoords) == 2 because check earlier + ax.plot(line[plotcoords[0]], line[plotcoords[1]], + color=color, linewidth=lollipop_tail_width) + ax.plot([line[plotcoords[0]][0]], + [line[plotcoords[1]][0]], + color=color, + marker=lollipop_marker, + markersize=lollipop_head_size) print(time.time() - t_start)