Skip to content
Snippets Groups Projects
Commit 6dce7ddc authored by Olivier Bertrand's avatar Olivier Bertrand
Browse files

Correct filtering, and add 2d feature for lollipops

parent bd6c3874
No related branches found
No related tags found
No related merge requests found
...@@ -56,7 +56,7 @@ class Trajectory(pd.DataFrame): ...@@ -56,7 +56,7 @@ class Trajectory(pd.DataFrame):
columns = self.__build_columns(rotconv) columns = self.__build_columns(rotconv)
super().__init__(index=indeces, columns=columns) super().__init__(index=indeces, columns=columns)
self.__rotconv = rotconv self.__rotconv = rotconv
self.__sampling_rate = np.nan self.sampling_rate = 0
def __build_columns(self, rotconv): def __build_columns(self, rotconv):
if rotconv == 'quaternion': if rotconv == 'quaternion':
...@@ -135,10 +135,7 @@ class Trajectory(pd.DataFrame): ...@@ -135,10 +135,7 @@ class Trajectory(pd.DataFrame):
@property @property
def sampling_rate(self): def sampling_rate(self):
if np.isnan(self.__sampling_rate): return self.__sampling_rate
raise NameError('Sampling rate has not be set')
else:
return self.index.name
@sampling_rate.setter @sampling_rate.setter
def sampling_rate(self, sampling_rate): def sampling_rate(self, sampling_rate):
...@@ -579,16 +576,15 @@ the trajectory. ...@@ -579,16 +576,15 @@ the trajectory.
otherwise relative to the Nyquist frequency. Either a number or a pandas \ otherwise relative to the Nyquist frequency. Either a number or a pandas \
series. series.
""" """
if isinstance(order, [int, float]): if self.sampling_rate > 0:
order = pd.Series(data=order, index=self.columns)
if isinstance(cutoff, [int, float]):
cutoff = pd.Series(data=cutoff, index=self.columns)
if not np.isnan(self.__sampling_rate):
nyquist = self.__sampling_rate/2 nyquist = self.__sampling_rate/2
cutoff /= nyquist cutoff /= nyquist
if isinstance(order, (int, float)):
order = pd.Series(data=order, index=self.columns)
if isinstance(cutoff, (int, float)):
cutoff = pd.Series(data=cutoff, index=self.columns)
subtraj = self.consecutive_blocks() subtraj = self.consecutive_blocks()
for trajno_nan in enumerate(subtraj): for trajno_nan in subtraj:
indeces = trajno_nan.index indeces = trajno_nan.index
for col in self.columns: for col in self.columns:
b, a = signal.butter(order.loc[col], cutoff.loc[col]) b, a = signal.butter(order.loc[col], cutoff.loc[col])
...@@ -597,10 +593,16 @@ series. ...@@ -597,10 +593,16 @@ series.
if trajno_nan.shape[0] < padlen: if trajno_nan.shape[0] < padlen:
self.loc[indeces, col] *= np.nan self.loc[indeces, col] *= np.nan
else: else:
self.loc[indeces, col] = signal.filtfilt( if col[0] == 'location':
b, a, self.loc[indeces, col] = signal.filtfilt(
trajno_nan.loc[:, col], b, a,
padlen=padlen) trajno_nan.loc[:, col],
padlen=padlen).astype(float)
else:
self.loc[indeces, col] = signal.filtfilt(
b, a,
np.unwrap(trajno_nan.loc[:, col]),
padlen=padlen).astype(float)
def fillna(self, method='Cubic'): def fillna(self, method='Cubic'):
""" fillna with a given method """ fillna with a given method
...@@ -669,7 +671,8 @@ series. ...@@ -669,7 +671,8 @@ series.
colors=None, step_lollipop=1, colors=None, step_lollipop=1,
offset_lollipop=0, lollipop_marker='o', offset_lollipop=0, lollipop_marker='o',
linewidth=1, lollipop_tail_width=1, lollipop_tail_length=1, linewidth=1, lollipop_tail_width=1, lollipop_tail_length=1,
lollipop_head_size=1, stickdir='backward' lollipop_head_size=1, stickdir='backward',
plotcoords=['x', 'y', 'z']
): ):
""" lollipops plot """ lollipops plot
...@@ -699,12 +702,27 @@ series. ...@@ -699,12 +702,27 @@ series.
:param lollipop_head_size: The size of the lollipop :param lollipop_head_size: The size of the lollipop
:param stickdir: The direction of the stick of the animal \ :param stickdir: The direction of the stick of the animal \
(backward or forward) (backward or forward)
:param plotcoords: the dimension to plots, e.g. ['x','y','z'] for 3d plots \
['x','y'] for a 2d plot
""" """
# import time # import time
t_start = time.time() t_start = time.time()
if ax is None: if ax is None:
fig = plt.figure() if len(plotcoords) == 3:
ax = fig.add_subplot(111, projection='3d') fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
elif len(plotcoords) == 2:
ax = fig.add_subplot(111)
if (len(plotcoords) != 2) and (len(plotcoords) != 3):
msg = 'plotcoords need to contains 2 or 3 elements'
msg += ' for 2d and 3d plots respectively'
raise ValueError(msg)
if ax.name == '3d':
plotcoords = ['x', 'y', 'z']
elif len(plotcoords) > 2:
plotcoords = plotcoords[:2]
# Start computing for direction
direction = self.facing_direction() direction = self.facing_direction()
if colors is None: if colors is None:
timeseries = pd.Series(data=self.index, timeseries = pd.Series(data=self.index,
...@@ -721,9 +739,9 @@ series. ...@@ -721,9 +739,9 @@ series.
[[0], [[0],
['x', 'y', 'z']])) ['x', 'y', 'z']]))
if stickdir == 'forward': if stickdir == 'forward':
tailmarker.loc[0, 'x'] = -lollipop_tail_length
else:
tailmarker.loc[0, 'x'] = lollipop_tail_length tailmarker.loc[0, 'x'] = lollipop_tail_length
else:
tailmarker.loc[0, 'x'] = -lollipop_tail_length
tail = self.world2body(tailmarker, indeces=indeces) tail = self.world2body(tailmarker, indeces=indeces)
tail = tail.loc[:, 0] tail = tail.loc[:, 0]
# Plot the agent trajectory # Plot the agent trajectory
...@@ -735,6 +753,10 @@ series. ...@@ -735,6 +753,10 @@ series.
z = self.loc[:, ('location', 'z')] z = self.loc[:, ('location', 'z')]
print(time.time() - t_start) print(time.time() - t_start)
t_start = time.time() t_start = time.time()
line = dict()
line['x'] = self.x
line['y'] = self.y
line['z'] = self.z
if isinstance(colors, pd.DataFrame): if isinstance(colors, pd.DataFrame):
# Each segment will be plotted with a different color # Each segment will be plotted with a different color
# we therefore need to loop through all points # we therefore need to loop through all points
...@@ -748,14 +770,27 @@ series. ...@@ -748,14 +770,27 @@ series.
color = [colors.r[frame_i], colors.g[frame_i], color = [colors.r[frame_i], colors.g[frame_i],
colors.b[frame_i], colors.a[frame_i]] colors.b[frame_i], colors.a[frame_i]]
# Create the line to plot # Create the line to plot
line_x = [x[frame_i], x[frame_j]] line['x'] = [x[frame_i], x[frame_j]]
line_y = [y[frame_i], y[frame_j]] line['y'] = [y[frame_i], y[frame_j]]
line_z = [z[frame_i], z[frame_j]] line['z'] = [z[frame_i], z[frame_j]]
# Actual plot command # Actual plot command
ax.plot(xs=line_x, ys=line_y, zs=line_z, if len(plotcoords) == 3:
color=color, linewidth=linewidth) ax.plot(xs=line['x'], ys=line['y'], zs=line['z'],
color=color, linewidth=linewidth)
else:
# len(plotcoords) == 2 because check earlier
ax.plot(line[plotcoords[0]], line[plotcoords[1]],
color=color, linewidth=linewidth)
else: else:
ax.plot(xs=x, ys=y, zs=z, color=colors, linewidth=linewidth) # Actual plot command
if len(plotcoords) == 3:
ax.plot(xs=line['x'], ys=line['y'], zs=line['z'],
color=colors, linewidth=linewidth)
else:
# len(plotcoords) == 2 because check earlier
ax.plot(line[plotcoords[0]], line[plotcoords[1]],
color=colors, linewidth=linewidth)
print(time.time() - t_start) print(time.time() - t_start)
t_start = time.time() t_start = time.time()
# Plot the lollipop # Plot the lollipop
...@@ -777,19 +812,29 @@ series. ...@@ -777,19 +812,29 @@ series.
else: else:
color = colors color = colors
# Create the line to plot # Create the line to plot
line_x = [self.x[frame_i], line['x'] = [self.x[frame_i],
tail.x[frame_i]] tail.x[frame_i]]
line_y = [self.y[frame_i], line['y'] = [self.y[frame_i],
tail.y[frame_i]] tail.y[frame_i]]
line_z = [self.z[frame_i], line['z'] = [self.z[frame_i],
tail.z[frame_i]] tail.z[frame_i]]
# Actual plot command # Actual plot command
ax.plot(xs=line_x, ys=line_y, zs=line_z, if len(plotcoords) == 3:
color=color, linewidth=lollipop_tail_width) ax.plot(xs=line['x'], ys=line['y'], zs=line['z'],
ax.plot(xs=[line_x[0]], color=color, linewidth=lollipop_tail_width)
ys=[line_y[0]], ax.plot(xs=[line['x'][0]],
zs=[line_z[0]], ys=[line['y'][0]],
color=color, zs=[line['z'][0]],
marker=lollipop_marker, color=color,
markersize=lollipop_head_size) marker=lollipop_marker,
markersize=lollipop_head_size)
else:
# len(plotcoords) == 2 because check earlier
ax.plot(line[plotcoords[0]], line[plotcoords[1]],
color=color, linewidth=lollipop_tail_width)
ax.plot([line[plotcoords[0]][0]],
[line[plotcoords[1]][0]],
color=color,
marker=lollipop_marker,
markersize=lollipop_head_size)
print(time.time() - t_start) print(time.time() - t_start)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment