Skip to content
Snippets Groups Projects
Commit 6dce7ddc authored by Olivier Bertrand's avatar Olivier Bertrand
Browse files

Correct filtering, and add 2d feature for lollipops

parent bd6c3874
No related branches found
No related tags found
No related merge requests found
......@@ -56,7 +56,7 @@ class Trajectory(pd.DataFrame):
columns = self.__build_columns(rotconv)
super().__init__(index=indeces, columns=columns)
self.__rotconv = rotconv
self.__sampling_rate = np.nan
self.sampling_rate = 0
def __build_columns(self, rotconv):
if rotconv == 'quaternion':
......@@ -135,10 +135,7 @@ class Trajectory(pd.DataFrame):
@property
def sampling_rate(self):
if np.isnan(self.__sampling_rate):
raise NameError('Sampling rate has not be set')
else:
return self.index.name
return self.__sampling_rate
@sampling_rate.setter
def sampling_rate(self, sampling_rate):
......@@ -579,16 +576,15 @@ the trajectory.
otherwise relative to the Nyquist frequency. Either a number or a pandas \
series.
"""
if isinstance(order, [int, float]):
order = pd.Series(data=order, index=self.columns)
if isinstance(cutoff, [int, float]):
cutoff = pd.Series(data=cutoff, index=self.columns)
if not np.isnan(self.__sampling_rate):
if self.sampling_rate > 0:
nyquist = self.__sampling_rate/2
cutoff /= nyquist
if isinstance(order, (int, float)):
order = pd.Series(data=order, index=self.columns)
if isinstance(cutoff, (int, float)):
cutoff = pd.Series(data=cutoff, index=self.columns)
subtraj = self.consecutive_blocks()
for trajno_nan in enumerate(subtraj):
for trajno_nan in subtraj:
indeces = trajno_nan.index
for col in self.columns:
b, a = signal.butter(order.loc[col], cutoff.loc[col])
......@@ -597,10 +593,16 @@ series.
if trajno_nan.shape[0] < padlen:
self.loc[indeces, col] *= np.nan
else:
self.loc[indeces, col] = signal.filtfilt(
b, a,
trajno_nan.loc[:, col],
padlen=padlen)
if col[0] == 'location':
self.loc[indeces, col] = signal.filtfilt(
b, a,
trajno_nan.loc[:, col],
padlen=padlen).astype(float)
else:
self.loc[indeces, col] = signal.filtfilt(
b, a,
np.unwrap(trajno_nan.loc[:, col]),
padlen=padlen).astype(float)
def fillna(self, method='Cubic'):
""" fillna with a given method
......@@ -669,7 +671,8 @@ series.
colors=None, step_lollipop=1,
offset_lollipop=0, lollipop_marker='o',
linewidth=1, lollipop_tail_width=1, lollipop_tail_length=1,
lollipop_head_size=1, stickdir='backward'
lollipop_head_size=1, stickdir='backward',
plotcoords=['x', 'y', 'z']
):
""" lollipops plot
......@@ -699,12 +702,27 @@ series.
:param lollipop_head_size: The size of the lollipop
:param stickdir: The direction of the stick of the animal \
(backward or forward)
:param plotcoords: the dimension to plots, e.g. ['x','y','z'] for 3d plots \
['x','y'] for a 2d plot
"""
# import time
t_start = time.time()
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
if len(plotcoords) == 3:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
elif len(plotcoords) == 2:
ax = fig.add_subplot(111)
if (len(plotcoords) != 2) and (len(plotcoords) != 3):
msg = 'plotcoords need to contains 2 or 3 elements'
msg += ' for 2d and 3d plots respectively'
raise ValueError(msg)
if ax.name == '3d':
plotcoords = ['x', 'y', 'z']
elif len(plotcoords) > 2:
plotcoords = plotcoords[:2]
# Start computing for direction
direction = self.facing_direction()
if colors is None:
timeseries = pd.Series(data=self.index,
......@@ -721,9 +739,9 @@ series.
[[0],
['x', 'y', 'z']]))
if stickdir == 'forward':
tailmarker.loc[0, 'x'] = -lollipop_tail_length
else:
tailmarker.loc[0, 'x'] = lollipop_tail_length
else:
tailmarker.loc[0, 'x'] = -lollipop_tail_length
tail = self.world2body(tailmarker, indeces=indeces)
tail = tail.loc[:, 0]
# Plot the agent trajectory
......@@ -735,6 +753,10 @@ series.
z = self.loc[:, ('location', 'z')]
print(time.time() - t_start)
t_start = time.time()
line = dict()
line['x'] = self.x
line['y'] = self.y
line['z'] = self.z
if isinstance(colors, pd.DataFrame):
# Each segment will be plotted with a different color
# we therefore need to loop through all points
......@@ -748,14 +770,27 @@ series.
color = [colors.r[frame_i], colors.g[frame_i],
colors.b[frame_i], colors.a[frame_i]]
# Create the line to plot
line_x = [x[frame_i], x[frame_j]]
line_y = [y[frame_i], y[frame_j]]
line_z = [z[frame_i], z[frame_j]]
line['x'] = [x[frame_i], x[frame_j]]
line['y'] = [y[frame_i], y[frame_j]]
line['z'] = [z[frame_i], z[frame_j]]
# Actual plot command
ax.plot(xs=line_x, ys=line_y, zs=line_z,
color=color, linewidth=linewidth)
if len(plotcoords) == 3:
ax.plot(xs=line['x'], ys=line['y'], zs=line['z'],
color=color, linewidth=linewidth)
else:
# len(plotcoords) == 2 because check earlier
ax.plot(line[plotcoords[0]], line[plotcoords[1]],
color=color, linewidth=linewidth)
else:
ax.plot(xs=x, ys=y, zs=z, color=colors, linewidth=linewidth)
# Actual plot command
if len(plotcoords) == 3:
ax.plot(xs=line['x'], ys=line['y'], zs=line['z'],
color=colors, linewidth=linewidth)
else:
# len(plotcoords) == 2 because check earlier
ax.plot(line[plotcoords[0]], line[plotcoords[1]],
color=colors, linewidth=linewidth)
print(time.time() - t_start)
t_start = time.time()
# Plot the lollipop
......@@ -777,19 +812,29 @@ series.
else:
color = colors
# Create the line to plot
line_x = [self.x[frame_i],
tail.x[frame_i]]
line_y = [self.y[frame_i],
tail.y[frame_i]]
line_z = [self.z[frame_i],
tail.z[frame_i]]
line['x'] = [self.x[frame_i],
tail.x[frame_i]]
line['y'] = [self.y[frame_i],
tail.y[frame_i]]
line['z'] = [self.z[frame_i],
tail.z[frame_i]]
# Actual plot command
ax.plot(xs=line_x, ys=line_y, zs=line_z,
color=color, linewidth=lollipop_tail_width)
ax.plot(xs=[line_x[0]],
ys=[line_y[0]],
zs=[line_z[0]],
color=color,
marker=lollipop_marker,
markersize=lollipop_head_size)
if len(plotcoords) == 3:
ax.plot(xs=line['x'], ys=line['y'], zs=line['z'],
color=color, linewidth=lollipop_tail_width)
ax.plot(xs=[line['x'][0]],
ys=[line['y'][0]],
zs=[line['z'][0]],
color=color,
marker=lollipop_marker,
markersize=lollipop_head_size)
else:
# len(plotcoords) == 2 because check earlier
ax.plot(line[plotcoords[0]], line[plotcoords[1]],
color=color, linewidth=lollipop_tail_width)
ax.plot([line[plotcoords[0]][0]],
[line[plotcoords[1]][0]],
color=color,
marker=lollipop_marker,
markersize=lollipop_head_size)
print(time.time() - t_start)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment