diff --git a/code/useful_functions.py b/code/useful_functions.py
index 15010e1..c88d006 100644
--- a/code/useful_functions.py
+++ b/code/useful_functions.py
@@ -110,6 +110,8 @@ def extract_stim_data(stimulus):
         Current EODf.
     stim_freq : float
         The total stimulus frequency (EODF+df).
+    stim_dur : float
+        The stimulus duration.
     amp_mod : float
         The current amplitude modulation.
     ny_freq : float
@@ -122,9 +124,10 @@ def extract_stim_data(stimulus):
     df = stimulus.metadata[stimulus.name]['DeltaF'][0][0]
     eodf = round(stimulus.metadata[stimulus.name]['EODf'][0][0])
     stim_freq = round(stimulus.metadata[stimulus.name]['Frequency'][0][0])
+    stim_dur = stimulus.duration
     # calculates the amplitude modulation
     amp_mod, ny_freq = AM(eodf, stim_freq)
-    return amplitude, df, eodf, stim_freq, amp_mod, ny_freq
+    return amplitude, df, eodf, stim_freq,stim_dur, amp_mod, ny_freq
 
 def find_exceeding_points(frequency, power, points, delta, threshold):
     """
@@ -286,6 +289,8 @@ def sam_data(sam):
 
     Returns
     -------
+    avg_dur : float
+        Average stimulus duarion.
     sam_amp : float
         amplitude in percent, relative to the fish amplitude.
     sam_am : float
@@ -307,19 +312,21 @@ def sam_data(sam):
     stim_freqs = []
     amp_mods = []
     ny_freqs = []
+    durations = []
     
     # get the stimuli
     stimuli = sam.stimuli
     
     # loop over the stimuli
     for stim in stimuli:
-        amplitude, df, eodf, stim_freq, amp_mod, ny_freq = extract_stim_data(stim)
+        amplitude, df, eodf, stim_freq,stim_dur, amp_mod, ny_freq = extract_stim_data(stim)
         amplitudes.append(amplitude)
         dfs.append(df)
         eodfs.append(eodf)
         stim_freqs.append(stim_freq)
         amp_mods.append(amp_mod)
         ny_freqs.append(ny_freq)
+        durations.append(stim_dur)
       
     # get the means
     sam_amp = np.mean(amplitudes)
@@ -328,7 +335,8 @@ def sam_data(sam):
     sam_eodf = np.mean(eodfs)
     sam_nyquist = np.mean(ny_freqs)
     sam_stim = np.mean(stim_freqs)
-    return sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim
+    avg_dur = np.mean(durations)
+    return avg_dur, sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim
 
 def spike_times(stim):
     """