# Source code for L5NeuronSimulation.my_plotting

"""
Tools for plotting various quantities throughout the project.
Also includes helpers for navigating the nested group structure of BMTK h5 files.
"""

import matplotlib.pyplot as plt
import numpy as np
import h5py
import scipy.signal as s

def get_key(group, index=0):
    """Return the key at position ``index`` among ``group``'s keys.

    Parameters
    ----------
    group : h5py.Group
        the h5py group (or any mapping exposing ``keys()``) to read a key from
    index : int, optional
        position of the key in ``list(group.keys())``; negative values count
        from the end, by default 0

    Returns
    -------
    str
        the key found at ``index``

    Raises
    ------
    IndexError
        if ``index`` is outside the range of available keys
    """
    key_names = [*group.keys()]
    return key_names[index]
def load_dataset(fname, groups=2):
    """Open an h5 file read-only and descend ``groups`` levels into it.

    Many BMTK h5 files nest their contents under a chain of groups that each
    hold a single member. This function follows that chain by taking the
    first key (``get_key(node)``, i.e. index 0) at every level.

    Parameters
    ----------
    fname : str
        path of the h5 file to open (opened in read-only mode)
    groups : int, optional
        number of levels to descend, by default 2; ``0`` returns the open
        ``h5py.File`` itself

    Returns
    -------
    h5py.Group or h5py.Dataset
        the object reached after descending. Depending on the file layout and
        ``groups`` this is either a group whose members are looked up by name
        (as ``plot_spikes`` / ``plot_v`` do with ``['timestamps']`` or
        ``['data']``) or a dataset that can be sliced directly (as
        ``plot_se`` does).

    Notes
    -----
    The file is deliberately not closed here: the returned object reads from
    it lazily. Close it through the returned object's ``.file`` attribute
    (``obj.file.close()``) once the data is no longer needed.
    """
    node = h5py.File(fname, 'r')
    for _ in range(groups):
        # Only the first member of each level is followed; siblings are ignored.
        node = node[get_key(node)]
    return node
def plot_spikes(file, show=False, id_scale=-1, id_shift=0, time_scale=10):
    """Plot a spike raster from a BMTK spikes h5 file.

    Each spike is drawn as a point at
    ``(timestamp * time_scale, node_id * scale + id_shift)``.

    Parameters
    ----------
    file : str
        location of the h5py file; the object two groups deep (see
        ``load_dataset``) must contain ``timestamps`` and ``node_ids``
    show : bool, optional
        whether to call plt.show() at the end, by default False
    id_scale : float, optional
        when positive, node ids are rescaled so that the largest id maps to
        ``id_scale``; by default -1 (no rescaling). Rescaling is skipped when
        there are no spikes or the largest id is 0.
    id_shift : float, optional
        constant added to the (possibly rescaled) node ids, by default 0
    time_scale : float, optional
        factor multiplied into every timestamp, by default 10
        (NOTE(review): presumably converts ms into 0.1 ms sample indices so
        the raster lines up with ``plot_v(ms=False)`` -- confirm)
    """
    data = load_dataset(file)
    # Read each dataset from disk once and reuse the in-memory copies.
    timestamps = np.array(data['timestamps'])
    node_ids = np.array(data['node_ids'])

    scale = 1
    if id_scale > 0 and node_ids.size > 0:
        max_id = np.max(node_ids)
        # A zero maximum (e.g. a single cell with id 0) would make the scale
        # infinite and every y value NaN, so leave the ids unscaled instead.
        if max_id != 0:
            scale = id_scale / max_id

    plt.plot(timestamps * time_scale, node_ids * scale + id_shift, '.')
    if show:
        plt.show()
def plot_v(file, show=False, ms=False):
    """Plot the first membrane-potential trace from a BMTK v_report.h5 file.

    Rows of the report's ``data`` dataset are plotted along the x axis and
    column 0 is the trace that gets drawn.

    Parameters
    ----------
    file : str
        location of the h5py file
    show : bool, optional
        whether to call plt.show() at the end, by default False
    ms : bool, optional
        whether to scale x by 0.1 to get ms scale, by default False
        (assumes a 0.1 ms recording step -- TODO confirm against the report)
    """
    traces = load_dataset(file)['data']
    # ``Dataset.shape`` comes from the file's metadata, so the sample count is
    # known without loading the whole matrix into memory.
    x = np.arange(traces.shape[0])
    if ms:
        x = x / 10
    plt.plot(x, traces[:, 0])
    if show:
        plt.show()
def plot_all_v(file, ms=False, show=True):
    """Plot every membrane-potential trace from a BMTK v_report.h5 file.

    Each column of the report's ``data`` dataset is drawn as its own line,
    with the row index on the x axis.

    Parameters
    ----------
    file : str
        location of the h5py file
    ms : bool, optional
        whether to scale x by 0.1 to get ms scale, by default False
        (assumes a 0.1 ms recording step -- TODO confirm against the report)
    show : bool, optional
        whether to call plt.show() at the end, by default True so existing
        callers still get a displayed figure; pass False to keep drawing on
        the current figure, like the other plotting helpers
    """
    # Load the matrix once; slicing columns out of the h5py dataset would go
    # back to the file for every trace.
    traces = np.array(load_dataset(file)['data'])
    x = np.arange(traces.shape[0])
    if ms:
        x = x / 10
    for column in range(traces.shape[1]):
        plt.plot(x, traces[:, column])
    if show:
        plt.show()
def plot_se(file, show=False):
    """Plot the first column of a BMTK se_clamp_report.

    Parameters
    ----------
    file : str
        location of the h5py file; the report is taken from one group deep
        (``load_dataset(file, groups=1)``)
    show : bool, optional
        whether to call plt.show() at the end, by default False
    """
    clamp_report = load_dataset(file, groups=1)
    first_trace = clamp_report[:, 0]
    plt.plot(first_trace)
    if show:
        plt.show()
# def generate_spike_probs(inh_file, spike_file, time): # gamma = generate_spike_gamma(inh_file, time) # data = load_dataset(spike_file) # timestamps = np.array(data['timestamps']) # troughs = s.find_peaks(-gamma)[0] # n_parts = 10 # parts = np.zeros(n_parts) # for i in range(len(troughs) - 1): # start = troughs[i] # part_len = (troughs[i+1] - start)/n_parts # for j in range(n_parts): # parts[j] += len(np.where((timestamps >= j*part_len + start) & (timestamps < (j+1)*part_len + start))[0]) # parts = np.array(parts) / parts.sum() # #t1 = gamma[troughs[100]:troughs[101]] # t1 = gamma[troughs[0]:troughs[1]] # plt.plot(parts, label="spike probability") # plt.plot(np.arange(len(t1)) * (n_parts/len(t1)), t1/10, label="gamma ex.") # plt.legend() # plt.show() # return parts # def generate_prob_raster(inh_file, spike_file, time): # gamma = generate_spike_gamma(inh_file, time) # data = load_dataset(spike_file) # node_ids = np.array(data['node_ids']) # timestamps = np.array(data['timestamps']) # troughs = s.find_peaks(-gamma)[0] # new_ts = np.zeros(len(timestamps)) # #ids = np.arange(len(timestamps)) # cycle_num = np.zeros(len(new_ts)) # for i in range(len(troughs) - 1): # start = troughs[i] # stop = troughs[i + 1] # length = stop - start # spikes = np.where((timestamps >= start) & (timestamps < stop))[0] # cycle_num[spikes] = i # times = timestamps[spikes] # times = times - start # times = times / length # new_ts[spikes] = times # #part_len = (troughs[i+1] - start)/n_parts # # for j in range(n_parts): # # parts[j] += len(np.where((timestamps >= j*part_len + start) & (timestamps < (j+1)*part_len + start))[0]) # parts = np.zeros(10) # sep = 0.1 # for i in range(10): # parts[i] = len(np.where((new_ts >= (i * sep)) & (new_ts < ((i+1)*sep)))[0]) # #parts = np.array(parts) / parts.sum() # #t1 = gamma[troughs[100]:troughs[101]] # t1 = gamma[troughs[0]:troughs[1]] # #import pdb; pdb.set_trace() # #plt.plot(parts, label="spike probability") # plt.plot(np.arange(10)+0.5, (parts / 
# parts.sum()), color="black", label="spike probability") # plt.plot(new_ts*(10), cycle_num/max(cycle_num), ".") # plt.xticks([0, 5, 10], labels = ["-" + r'$\pi$', 0, r'$\pi$']) # plt.axvline(x=5, ls="--", color = "black") # #plt.plot(np.arange(len(t1)) * (len(t1)/len(t1)), t1/3, label="gamma ex.") # #plt.plot(t1/3, label="gamma ex.") # plt.legend() # ax = plt.gca() # #ax.axes.xaxis.set_visible(False) # ax.axes.yaxis.set_visible(False) # plt.show() # #return parts # def plot_spike_gamma(file, time): # gamma = generate_spike_gamma(file, time) # # troughs = s.find_peaks(-smooth)[0] # # parts = np.zeros(10) # # for i in range(len(troughs) - 1): # # part_len = (troughs[i+1] - troughs[i])/10 # # for j in range(10): # # parts[j] += len(np.where((timestamps >= j*part_len) & (timestamps < (j+1)*part_len))[0]) # #import pdb; pdb.set_trace() # plt.plot(np.arange(time)*10, smooth)