spatial.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import numpy as np
  2. import scipy.ndimage as ndi
  3. from scipy import signal
  4. #from skimage.registration import phase_cross_correlation
  5. #from skimage.transform import warp_polar, rotate
  6. def bins2meters(x, y, xy_range, bin_size=0.02):
  7. x_in_m = x*bin_size - (xy_range[1] - xy_range[0])/2
  8. y_in_m = y*bin_size - (xy_range[3] - xy_range[2])/2
  9. return x_in_m, y_in_m
  10. def cart2pol(x, y):
  11. rho = np.sqrt(x**2 + y**2)
  12. phi = np.arctan2(y, x)
  13. return rho, phi
  14. def pol2cart(rho, phi):
  15. x = rho * np.cos(phi)
  16. y = rho * np.sin(phi)
  17. return (x, y)
  18. def gaussian_kernel_2D(sigma=0.1):
  19. lin_profile = np.linspace(-10, 10, 50)
  20. bump = np.exp(-sigma * lin_profile**2)
  21. bump /= np.trapz(bump) # normalize to 1
  22. return bump[:, np.newaxis] * bump[np.newaxis, :]
  23. def place_field_2D(pos, pos_firing, sampling_rate, bin_size=0.02, sigma=0.15, xy_range=None):
  24. """
  25. :param pos: x, y positions sampled at sampling_rate
  26. :param pos_firing: positions when spikes occured
  27. :param sampling_rate: sampling rate of above in Hz
  28. :param bin_size: size of the squared bin to calculate firing / occupancy,
  29. same units as in pos, e.g. meters
  30. :param sigma: standard deviation, for smoothing
  31. :param range: [xmin, xmax, ymin, ymax] - array of X and Y boundaries to limit the map
  32. :return:
  33. occupancy_map, spiking_map, firing_map, s_firing_map
  34. """
  35. x_min = xy_range[0] if xy_range is not None else pos[:, 0].min()
  36. x_max = xy_range[1] if xy_range is not None else pos[:, 0].max()
  37. y_min = xy_range[2] if xy_range is not None else pos[:, 1].min()
  38. y_max = xy_range[3] if xy_range is not None else pos[:, 1].max()
  39. x_range = x_max - x_min
  40. y_range = y_max - y_min
  41. y_bin_count = int(np.ceil(y_range / bin_size))
  42. x_bin_count = int(np.ceil(x_range / bin_size))
  43. pos_range = np.array([[x_min, x_max], [y_min, y_max]])
  44. occup, x_edges, y_edges = np.histogram2d(pos[:, 0], pos[:, 1], bins=[x_bin_count, y_bin_count], range=pos_range)
  45. occupancy_map = occup / sampling_rate
  46. # spiking map
  47. spiking_map, xs_edges, ys_edges = np.histogram2d(pos_firing[:, 0], pos_firing[:, 1], bins=[x_bin_count, y_bin_count], range=pos_range)
  48. # firing = spiking / occupancy
  49. firing_map = np.divide(spiking_map, occupancy_map, out=np.zeros_like(spiking_map, dtype=float), where=occupancy_map!=0)
  50. # apply gaussial smoothing
  51. kernel = gaussian_kernel_2D(sigma)
  52. s_firing_map = signal.convolve2d(firing_map, kernel, mode='same')
  53. occupancy_map = signal.convolve2d(occupancy_map, kernel, mode='same')
  54. return occupancy_map, spiking_map, firing_map, s_firing_map
  55. def map_stats(f_map, o_map):
  56. """
  57. f_map: 2D matrix of unit firing rate map
  58. o_map: 2D matrix of animal occupancy map
  59. """
  60. o_map_norm = o_map / o_map.sum()
  61. meanrate = np.nansum(np.nansum(np.multiply(f_map, o_map_norm)))
  62. meansquarerate = np.nansum(np.nansum(np.multiply(f_map ** 2, o_map_norm)))
  63. maxrate = np.max(f_map)
  64. sparsity = 0 if meansquarerate == 0 else meanrate**2 / meansquarerate
  65. selectivity = 0 if meanrate == 0 else maxrate / meanrate
  66. peak_FR = f_map.max()
  67. spatial_info = 0 # default value
  68. if peak_FR > 0:
  69. f_map_norm = f_map / meanrate
  70. tmp = np.multiply(o_map_norm, f_map_norm)
  71. tmp = np.multiply(tmp, np.log2(f_map_norm, where=f_map_norm > 0))
  72. spatial_info = np.nansum(tmp)
  73. return sparsity, selectivity, spatial_info, peak_FR
  74. def get_field_patches(f_map, threshold=0.5):
  75. """
  76. f_map: 2D matrix of unit firing rate map
  77. return: sorted field patches, where the largest field has index 1
  78. se also: https://exeter-data-analytics.github.io/python-data/skimage.html
  79. """
  80. patch_idxs = f_map > threshold*f_map.max()
  81. # label individual patches
  82. p_labels, _ = ndi.label(patch_idxs)
  83. # sort patches according to the patch size
  84. sort_idxs = (-1*np.bincount(p_labels.flat)[1:]).argsort()
  85. p_idxs = np.unique(p_labels)[1:]
  86. p_labels_sorted = np.zeros(p_labels.shape, dtype=np.int32)
  87. for i, idx in enumerate(sort_idxs):
  88. p_labels_sorted[p_labels == p_idxs[idx]] = i + 1
  89. return p_labels_sorted
  90. def best_match_rotation_polar(map_A, map_B):
  91. """ From
  92. https://scikit-image.org/docs/stable/auto_examples/registration/plot_register_rotation.html
  93. TODO explore log-polar transform
  94. """
  95. radius = map_A.shape[0]/2
  96. A_polar = warp_polar(map_A, radius=radius)
  97. B_polar = warp_polar(map_B, radius=radius)
  98. # this way doesn't work good
  99. #shifts, error, phasediff = phase_cross_correlation(A_polar, B_polar)
  100. ccf = signal.correlate(A_polar.T, B_polar.T, mode='same').sum(axis=0)
  101. return ccf, -(np.argmax(ccf) - 180)
  102. def best_match_rotation_pearson(map_A, map_B, delta=6):
  103. if not int(delta) == delta:
  104. raise ValueError('delta should be of type int')
  105. n = map_A.shape[0]
  106. a, b, r = n/2, n/2, int(0.93*n/2)
  107. y,x = np.ogrid[-a:n-a, -b:n-b]
  108. mask = x*x + y*y <= r*r
  109. angles_num = int(360/delta)
  110. angles = np.linspace(0, (angles_num - 1) * delta, angles_num) # in degrees
  111. corrs = np.zeros(len(angles))
  112. for i, alpha in enumerate(angles):
  113. corr_2D = np.corrcoef(map_A, ndi.rotate(map_B, alpha, reshape=False))
  114. #print(corr_2D.shape, mask.shape)
  115. #corrs[i] = corr_2D[mask].mean()
  116. corrs[i] = corr_2D.mean()
  117. # normalize before correlation?
  118. return angles, corrs, np.argmax(corrs)*6 # actual value
  119. def get_positions_relative_to(pos_alo, HD, poi):
  120. """
  121. pos_allo 2D array of alloentric animal positions X, Y (shape Nx2)
  122. HD vector of HD angles (in rad.) for each pos_allo (shape N)
  123. poi point of interest (x, y) relative to which to compute egocentric coords.
  124. """
  125. pos_poi = poi - pos_alo
  126. R = np.linalg.norm(pos_poi, axis=1) # distance from animal to the POI
  127. phi_alo = (np.degrees(np.arctan2(pos_poi[:, 1], pos_poi[:, 0])) - HD) % 360 # angle to POI in allocentric frame, in deg.
  128. phi_ego = phi_alo - np.rad2deg(HD) # angle to POI in egocentric frame, in deg.
  129. phi_ego = np.deg2rad(phi_ego)
  130. return np.array([np.multiply(R, np.cos(phi_ego)), np.multiply(R, np.sin(phi_ego))]).T