KCSD.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059
  1. #!/usr/bin/env python
  2. """This script is used to generate Current Source Density Estimates, using the
  3. kCSD method Jan et.al (2012).
  4. This was written by :
  5. [1]Chaitanya Chintaluri,
  6. [2]Michal Czerwinski,
  7. Laboratory of Neuroinformatics,
  8. Nencki Institute of Exprimental Biology, Warsaw.
  9. KCSD1D[1][2], KCSD2D[1], KCSD3D[1], MoIKCSD[1]
  10. """
  11. from __future__ import division
  12. import numpy as np
  13. from scipy import special, integrate, interpolate
  14. from scipy.spatial import distance
  15. from numpy.linalg import LinAlgError
  16. from . import utility_functions as utils
  17. from . import basis_functions as basis
  18. skmonaco_available = False
  19. class CSD(object):
  20. """CSD - The base class for KCSD methods."""
  21. def __init__(self, ele_pos, pots):
  22. self.validate(ele_pos, pots)
  23. self.ele_pos = ele_pos
  24. self.pots = pots
  25. self.n_ele = self.ele_pos.shape[0]
  26. self.n_time = self.pots.shape[1]
  27. self.dim = self.ele_pos.shape[1]
  28. self.cv_error = None
  29. def validate(self, ele_pos, pots):
  30. """Basic checks to see if inputs are okay
  31. Parameters
  32. ----------
  33. ele_pos : numpy array
  34. positions of electrodes
  35. pots : numpy array
  36. potentials measured by electrodes
  37. """
  38. if ele_pos.shape[0] != pots.shape[0]:
  39. raise Exception("Number of measured potentials is not equal "
  40. "to electrode number!")
  41. if ele_pos.shape[0] < 1+ele_pos.shape[1]: #Dim+1
  42. raise Exception("Number of electrodes must be at least :",
  43. 1+ele_pos.shape[1])
  44. if utils.contains_duplicated_electrodes(ele_pos):
  45. raise Exception("Error! Duplicated electrode!")
  46. def sanity(self, true_csd, pos_csd):
  47. """Useful for comparing TrueCSD with reconstructed CSD. Computes, the RMS error
  48. between the true_csd and the reconstructed csd at pos_csd using the
  49. method defined.
  50. Parameters
  51. ----------
  52. true_csd : csd values used to generate potentials
  53. pos_csd : csd estimatation from the method
  54. Returns
  55. -------
  56. RMSE : root mean squared difference
  57. """
  58. csd = self.values(pos_csd)
  59. RMSE = np.sqrt(np.mean(np.square(true_csd - csd)))
  60. return RMSE
  61. class KCSD(CSD):
  62. """KCSD - The base class for all the KCSD variants.
  63. This estimates the Current Source Density, for a given configuration of
  64. electrod positions and recorded potentials, electrodes.
  65. The method implented here is based on the original paper
  66. by Jan Potworowski et.al. 2012.
  67. """
  68. def __init__(self, ele_pos, pots, **kwargs):
  69. super(KCSD, self).__init__(ele_pos, pots)
  70. self.parameters(**kwargs)
  71. self.estimate_at()
  72. self.place_basis()
  73. self.create_src_dist_tables()
  74. self.method()
  75. def parameters(self, **kwargs):
  76. """Defining the default values of the method passed as kwargs
  77. Parameters
  78. ----------
  79. **kwargs
  80. Same as those passed to initialize the Class
  81. """
  82. self.src_type = kwargs.pop('src_type', 'gauss')
  83. self.sigma = kwargs.pop('sigma', 1.0)
  84. self.h = kwargs.pop('h', 1.0)
  85. self.n_src_init = kwargs.pop('n_src_init', 1000)
  86. self.lambd = kwargs.pop('lambd', 0.0)
  87. self.R_init = kwargs.pop('R_init', 0.23)
  88. self.ext_x = kwargs.pop('ext_x', 0.0)
  89. self.xmin = kwargs.pop('xmin', np.min(self.ele_pos[:, 0]))
  90. self.xmax = kwargs.pop('xmax', np.max(self.ele_pos[:, 0]))
  91. self.gdx = kwargs.pop('gdx', 0.01*(self.xmax - self.xmin))
  92. if self.dim >= 2:
  93. self.ext_y = kwargs.pop('ext_y', 0.0)
  94. self.ymin = kwargs.pop('ymin', np.min(self.ele_pos[:, 1]))
  95. self.ymax = kwargs.pop('ymax', np.max(self.ele_pos[:, 1]))
  96. self.gdy = kwargs.pop('gdy', 0.01*(self.ymax - self.ymin))
  97. if self.dim == 3:
  98. self.ext_z = kwargs.pop('ext_z', 0.0)
  99. self.zmin = kwargs.pop('zmin', np.min(self.ele_pos[:, 2]))
  100. self.zmax = kwargs.pop('zmax', np.max(self.ele_pos[:, 2]))
  101. self.gdz = kwargs.pop('gdz', 0.01*(self.zmax - self.zmin))
  102. if kwargs:
  103. raise TypeError('Invalid keyword arguments:', kwargs.keys())
  104. def method(self):
  105. """Actual sequence of methods called for KCSD
  106. Defines:
  107. self.k_pot and self.k_interp_cross matrices
  108. Parameters
  109. ----------
  110. None
  111. """
  112. self.create_lookup() #Look up table
  113. self.update_b_pot() #update kernel
  114. self.update_b_src() #update crskernel
  115. self.update_b_interp_pot() #update pot interp
  116. def create_lookup(self, dist_table_density=20):
  117. """Creates a table for easy potential estimation from CSD.
  118. Updates and Returns the potentials due to a
  119. given basis source like a lookup
  120. table whose shape=(dist_table_density,)
  121. Parameters
  122. ----------
  123. dist_table_density : int
  124. number of distance values at which potentials are computed.
  125. Default 100
  126. """
  127. xs = np.logspace(0., np.log10(self.dist_max+1.), dist_table_density)
  128. xs = xs - 1.0 #starting from 0
  129. dist_table = np.zeros(len(xs))
  130. for i, pos in enumerate(xs):
  131. dist_table[i] = self.forward_model(pos,
  132. self.R,
  133. self.h,
  134. self.sigma,
  135. self.basis)
  136. self.interpolate_pot_at = interpolate.interp1d(xs, dist_table, kind='cubic')
  137. def update_b_pot(self):
  138. """Updates the b_pot - array is (#_basis_sources, #_electrodes)
  139. Updates the k_pot - array is (#_electrodes, #_electrodes) K(x,x')
  140. Eq9,Jan2012
  141. Calculates b_pot - matrix containing the values of all
  142. the potential basis functions in all the electrode positions
  143. (essential for calculating the cross_matrix).
  144. Parameters
  145. ----------
  146. None
  147. """
  148. self.b_pot = self.interpolate_pot_at(self.src_ele_dists)
  149. self.k_pot = np.dot(self.b_pot.T, self.b_pot) #K(x,x') Eq9,Jan2012
  150. self.k_pot /= self.n_src
  151. def update_b_src(self):
  152. """Updates the b_src in the shape of (#_est_pts, #_basis_sources)
  153. Updates the k_interp_cross - K_t(x,y) Eq17
  154. Calculate b_src - matrix containing containing the values of
  155. all the source basis functions in all the points at which we want to
  156. calculate the solution (essential for calculating the cross_matrix)
  157. Parameters
  158. ----------
  159. None
  160. """
  161. self.b_src = self.basis(self.src_estm_dists, self.R).T
  162. self.k_interp_cross = np.dot(self.b_src, self.b_pot) #K_t(x,y) Eq17
  163. self.k_interp_cross /= self.n_src
  164. def update_b_interp_pot(self):
  165. """Compute the matrix of potentials generated by every source
  166. basis function at every position in the interpolated space.
  167. Updates b_interp_pot
  168. Updates k_interp_pot
  169. Parameters
  170. ----------
  171. None
  172. """
  173. self.b_interp_pot = self.interpolate_pot_at(self.src_estm_dists).T
  174. self.k_interp_pot = np.dot(self.b_interp_pot, self.b_pot)
  175. self.k_interp_pot /= self.n_src
  176. def values(self, estimate='CSD'):
  177. """Computes the values of the quantity of interest
  178. Parameters
  179. ----------
  180. estimate : 'CSD' or 'POT'
  181. What quantity is to be estimated
  182. Defaults to 'CSD'
  183. Returns
  184. -------
  185. estimation : np.array
  186. estimated quantity of shape (ngx, ngy, ngz, nt)
  187. """
  188. if estimate == 'CSD': #Maybe used for estimating the potentials also.
  189. estimation_table = self.k_interp_cross
  190. elif estimate == 'POT':
  191. estimation_table = self.k_interp_pot
  192. else:
  193. print('Invalid quantity to be measured, pass either CSD or POT')
  194. k_inv = np.linalg.inv(self.k_pot + self.lambd *
  195. np.identity(self.k_pot.shape[0]))
  196. estimation = np.zeros((self.n_estm, self.n_time))
  197. for t in range(self.n_time):
  198. beta = np.dot(k_inv, self.pots[:, t])
  199. for i in range(self.n_ele):
  200. estimation[:, t] += estimation_table[:, i] *beta[i] # C*(x) Eq 18
  201. return self.process_estimate(estimation)
  202. def process_estimate(self, estimation):
  203. """Function used to rearrange estimation according to dimension, to be
  204. used by the fuctions values
  205. Parameters
  206. ----------
  207. estimation : np.array
  208. Returns
  209. -------
  210. estimation : np.array
  211. estimated quantity of shape (ngx, ngy, ngz, nt)
  212. """
  213. if self.dim == 1:
  214. estimation = estimation.reshape(self.ngx, self.n_time)
  215. elif self.dim == 2:
  216. estimation = estimation.reshape(self.ngx, self.ngy, self.n_time)
  217. elif self.dim == 3:
  218. estimation = estimation.reshape(self.ngx, self.ngy, self.ngz, self.n_time)
  219. return estimation
  220. def update_R(self, R):
  221. """Update the width of the basis fuction - Used in Cross validation
  222. Parameters
  223. ----------
  224. R : float
  225. """
  226. self.R = R
  227. self.dist_max = max(np.max(self.src_ele_dists),
  228. np.max(self.src_estm_dists)) + self.R
  229. self.method()
  230. def update_lambda(self, lambd):
  231. """Update the lambda parameter of regularization, Used in Cross validation
  232. Parameters
  233. ----------
  234. lambd : float
  235. """
  236. self.lambd = lambd
  237. def cross_validate(self, lambdas=None, Rs=None):
  238. """Method defines the cross validation.
  239. By default only cross_validates over lambda,
  240. When no argument is passed, it takes
  241. lambdas = np.logspace(-2,-25,25,base=10.)
  242. and Rs = np.array(self.R).flatten()
  243. otherwise pass necessary numpy arrays
  244. Parameters
  245. ----------
  246. lambdas : numpy array
  247. Rs : numpy array
  248. Returns
  249. -------
  250. R : post cross validation
  251. Lambda : post cross validation
  252. """
  253. if lambdas is None: #when None
  254. print('No lambda given, using defaults')
  255. lambdas = np.logspace(-2,-25,25,base=10.) #Default multiple lambda
  256. lambdas = np.hstack((lambdas, np.array((0.0))))
  257. elif lambdas.size == 1: #resize when one entry
  258. lambdas = lambdas.flatten()
  259. if Rs is None: #when None
  260. Rs = np.array((self.R)).flatten() #Default over one R value
  261. errs = np.zeros((Rs.size, lambdas.size))
  262. index_generator = []
  263. for ii in range(self.n_ele):
  264. idx_test = [ii]
  265. idx_train = list(range(self.n_ele))
  266. idx_train.remove(ii) #Leave one out
  267. index_generator.append((idx_train, idx_test))
  268. for R_idx,R in enumerate(Rs): #Iterate over R
  269. self.update_R(R)
  270. print('Cross validating R (all lambda) :', R)
  271. for lambd_idx,lambd in enumerate(lambdas): #Iterate over lambdas
  272. errs[R_idx, lambd_idx] = self.compute_cverror(lambd,
  273. index_generator)
  274. err_idx = np.where(errs==np.min(errs)) #Index of the least error
  275. cv_R = Rs[err_idx[0]][0] #First occurance of the least error's
  276. cv_lambda = lambdas[err_idx[1]][0]
  277. self.cv_error = np.min(errs) #otherwise is None
  278. self.update_R(cv_R) #Update solver
  279. self.update_lambda(cv_lambda)
  280. print('R, lambda :', cv_R, cv_lambda)
  281. return cv_R, cv_lambda
  282. def compute_cverror(self, lambd, index_generator):
  283. """Useful for Cross validation error calculations
  284. Parameters
  285. ----------
  286. lambd : float
  287. index_generator : list
  288. Returns
  289. -------
  290. err : float
  291. the sum of the error computed.
  292. """
  293. err = 0
  294. for idx_train, idx_test in index_generator:
  295. B_train = self.k_pot[np.ix_(idx_train, idx_train)]
  296. V_train = self.pots[idx_train]
  297. V_test = self.pots[idx_test]
  298. I_matrix = np.identity(len(idx_train))
  299. B_new = np.matrix(B_train) + (lambd*I_matrix)
  300. try:
  301. beta_new = np.dot(np.matrix(B_new).I, np.matrix(V_train))
  302. B_test = self.k_pot[np.ix_(idx_test, idx_train)]
  303. V_est = np.zeros((len(idx_test), self.pots.shape[1]))
  304. for ii in range(len(idx_train)):
  305. for tt in range(self.pots.shape[1]):
  306. V_est[:, tt] += beta_new[ii, tt] * B_test[:, ii]
  307. err += np.linalg.norm(V_est-V_test)
  308. except LinAlgError:
  309. raise LinAlgError('Encoutered Singular Matrix Error: try changing ele_pos slightly')
  310. return err
  311. class KCSD1D(KCSD):
  312. """KCSD1D - The 1D variant for the Kernel Current Source Density method.
  313. This estimates the Current Source Density, for a given configuration of
  314. electrod positions and recorded potentials, in the case of 1D recording
  315. electrodes (laminar probes). The method implented here is based on the
  316. original paper by Jan Potworowski et.al. 2012.
  317. """
  318. def __init__(self, ele_pos, pots, **kwargs):
  319. """Initialize KCSD1D Class.
  320. Parameters
  321. ----------
  322. ele_pos : numpy array
  323. positions of electrodes
  324. pots : numpy array
  325. potentials measured by electrodes
  326. **kwargs
  327. configuration parameters, that may contain the following keys:
  328. src_type : str
  329. basis function type ('gauss', 'step', 'gauss_lim')
  330. Defaults to 'gauss'
  331. sigma : float
  332. space conductance of the tissue in S/m
  333. Defaults to 1 S/m
  334. n_src_init : int
  335. requested number of sources
  336. Defaults to 300
  337. R_init : float
  338. demanded thickness of the basis element
  339. Defaults to 0.23
  340. h : float
  341. thickness of analyzed cylindrical slice
  342. Defaults to 1.
  343. xmin, xmax : floats
  344. boundaries for CSD estimation space
  345. Defaults to min(ele_pos(x)), and max(ele_pos(x))
  346. ext_x : float
  347. length of space extension: x_min-ext_x ... x_max+ext_x
  348. Defaults to 0.
  349. gdx : float
  350. space increments in the estimation space
  351. Defaults to 0.01(xmax-xmin)
  352. lambd : float
  353. regularization parameter for ridge regression
  354. Defaults to 0.
  355. Raises
  356. ------
  357. LinAlgException
  358. If the matrix is not numerically invertible.
  359. KeyError
  360. Basis function (src_type) not implemented. See basis_functions.py for available
  361. """
  362. super(KCSD1D, self).__init__(ele_pos, pots, **kwargs)
  363. def estimate_at(self):
  364. """Defines locations where the estimation is wanted
  365. Defines:
  366. self.n_estm = self.estm_x.size
  367. self.ngx = self.estm_x.shape
  368. self.estm_x : Locations at which CSD is requested.
  369. Parameters
  370. ----------
  371. None
  372. """
  373. nx = (self.xmax - self.xmin)/self.gdx
  374. self.estm_x = np.mgrid[self.xmin:self.xmax:np.complex(0,nx)]
  375. self.n_estm = self.estm_x.size
  376. self.ngx = self.estm_x.shape[0]
  377. def place_basis(self):
  378. """Places basis sources of the defined type.
  379. Checks if a given source_type is defined, if so then defines it
  380. self.basis, This function gives locations of the basis sources,
  381. Defines
  382. source_type : basis_fuctions.basis_1D.keys()
  383. self.R based on R_init
  384. self.dist_max as maximum distance between electrode and basis
  385. self.nsx = self.src_x.shape
  386. self.src_x : Locations at which basis sources are placed.
  387. Parameters
  388. ----------
  389. None
  390. """
  391. source_type = self.src_type
  392. try:
  393. self.basis = basis.basis_1D[source_type]
  394. except KeyError:
  395. raise KeyError('Invalid source_type for basis! available are:',
  396. basis.basis_1D.keys())
  397. (self.src_x, self.R) = utils.distribute_srcs_1D(self.estm_x,
  398. self.n_src_init,
  399. self.ext_x,
  400. self.R_init )
  401. self.n_src = self.src_x.size
  402. self.nsx = self.src_x.shape
  403. def create_src_dist_tables(self):
  404. """Creates distance tables between sources, electrode and estm points
  405. Parameters
  406. ----------
  407. None
  408. """
  409. src_loc = np.array((self.src_x.ravel()))
  410. src_loc = src_loc.reshape((len(src_loc), 1))
  411. est_loc = np.array((self.estm_x.ravel()))
  412. est_loc = est_loc.reshape((len(est_loc), 1))
  413. self.src_ele_dists = distance.cdist(src_loc, self.ele_pos, 'euclidean')
  414. self.src_estm_dists = distance.cdist(src_loc, est_loc, 'euclidean')
  415. self.dist_max = max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R
  416. def forward_model(self, x, R, h, sigma, src_type):
  417. """FWD model functions
  418. Evaluates potential at point (x,0) by a basis source located at (0,0)
  419. Eq 26 kCSD by Jan,2012
  420. Parameters
  421. ----------
  422. x : float
  423. R : float
  424. h : float
  425. sigma : float
  426. src_type : basis_1D.key
  427. Returns
  428. -------
  429. pot : float
  430. value of potential at specified distance from the source
  431. """
  432. pot, err = integrate.quad(self.int_pot_1D,
  433. -R, R,
  434. args=(x, R, h, src_type))
  435. pot *= 1./(2.0*sigma)
  436. return pot
  437. def int_pot_1D(self, xp, x, R, h, basis_func):
  438. """FWD model function.
  439. Returns contribution of a point xp,yp, belonging to a basis source
  440. support centered at (0,0) to the potential measured at (x,0),
  441. integrated over xp,yp gives the potential generated by a
  442. basis source element centered at (0,0) at point (x,0)
  443. Eq 26 kCSD by Jan,2012
  444. Parameters
  445. ----------
  446. xp : floats or np.arrays
  447. point or set of points where function should be calculated
  448. x : float
  449. position at which potential is being measured
  450. R : float
  451. The size of the basis function
  452. h : float
  453. thickness of slice
  454. basis_func : method
  455. Fuction of the basis source
  456. Returns
  457. -------
  458. pot : float
  459. """
  460. m = np.sqrt((x-xp)**2 + h**2) - abs(x-xp)
  461. m *= basis_func(abs(xp), R) #xp is the distance
  462. return m
  463. class KCSD2D(KCSD):
  464. """KCSD2D - The 2D variant for the Kernel Current Source Density method.
  465. This estimates the Current Source Density, for a given configuration of
  466. electrod positions and recorded potentials, in the case of 2D recording
  467. electrodes. The method implented here is based on the original paper
  468. by Jan Potworowski et.al. 2012.
  469. """
  470. def __init__(self, ele_pos, pots, **kwargs):
  471. """Initialize KCSD2D Class.
  472. Parameters
  473. ----------
  474. ele_pos : numpy array
  475. positions of electrodes
  476. pots : numpy array
  477. potentials measured by electrodes
  478. **kwargs
  479. configuration parameters, that may contain the following keys:
  480. src_type : str
  481. basis function type ('gauss', 'step', 'gauss_lim')
  482. Defaults to 'gauss'
  483. sigma : float
  484. space conductance of the tissue in S/m
  485. Defaults to 1 S/m
  486. n_src_init : int
  487. requested number of sources
  488. Defaults to 1000
  489. R_init : float
  490. demanded thickness of the basis element
  491. Defaults to 0.23
  492. h : float
  493. thickness of analyzed tissue slice
  494. Defaults to 1.
  495. xmin, xmax, ymin, ymax : floats
  496. boundaries for CSD estimation space
  497. Defaults to min(ele_pos(x)), and max(ele_pos(x))
  498. Defaults to min(ele_pos(y)), and max(ele_pos(y))
  499. ext_x, ext_y : float
  500. length of space extension: x_min-ext_x ... x_max+ext_x
  501. length of space extension: y_min-ext_y ... y_max+ext_y
  502. Defaults to 0.
  503. gdx, gdy : float
  504. space increments in the estimation space
  505. Defaults to 0.01(xmax-xmin)
  506. Defaults to 0.01(ymax-ymin)
  507. lambd : float
  508. regularization parameter for ridge regression
  509. Defaults to 0.
  510. Raises
  511. ------
  512. LinAlgError
  513. Could not invert the matrix, try changing the ele_pos slightly
  514. KeyError
  515. Basis function (src_type) not implemented. See basis_functions.py for available
  516. """
  517. super(KCSD2D, self).__init__(ele_pos, pots, **kwargs)
  518. def estimate_at(self):
  519. """Defines locations where the estimation is wanted
  520. Defines:
  521. self.n_estm = self.estm_x.size
  522. self.ngx, self.ngy = self.estm_x.shape
  523. self.estm_x, self.estm_y : Locations at which CSD is requested.
  524. Parameters
  525. ----------
  526. None
  527. """
  528. nx = (self.xmax - self.xmin)/self.gdx
  529. ny = (self.ymax - self.ymin)/self.gdy
  530. self.estm_x, self.estm_y = np.mgrid[self.xmin:self.xmax:np.complex(0,nx),
  531. self.ymin:self.ymax:np.complex(0,ny)]
  532. self.n_estm = self.estm_x.size
  533. self.ngx, self.ngy = self.estm_x.shape
  534. def place_basis(self):
  535. """Places basis sources of the defined type.
  536. Checks if a given source_type is defined, if so then defines it
  537. self.basis, This function gives locations of the basis sources,
  538. Defines
  539. source_type : basis_fuctions.basis_2D.keys()
  540. self.R based on R_init
  541. self.dist_max as maximum distance between electrode and basis
  542. self.nsx, self.nsy = self.src_x.shape
  543. self.src_x, self.src_y : Locations at which basis sources are placed.
  544. Parameters
  545. ----------
  546. None
  547. """
  548. source_type = self.src_type
  549. try:
  550. self.basis = basis.basis_2D[source_type]
  551. except KeyError:
  552. raise KeyError('Invalid source_type for basis! available are:',
  553. basis.basis_2D.keys())
  554. (self.src_x, self.src_y, self.R) = utils.distribute_srcs_2D(self.estm_x,
  555. self.estm_y,
  556. self.n_src_init,
  557. self.ext_x,
  558. self.ext_y,
  559. self.R_init )
  560. self.n_src = self.src_x.size
  561. self.nsx, self.nsy = self.src_x.shape
  562. def create_src_dist_tables(self):
  563. """Creates distance tables between sources, electrode and estm points
  564. Parameters
  565. ----------
  566. None
  567. """
  568. src_loc = np.array((self.src_x.ravel(), self.src_y.ravel()))
  569. est_loc = np.array((self.estm_x.ravel(), self.estm_y.ravel()))
  570. self.src_ele_dists = distance.cdist(src_loc.T, self.ele_pos, 'euclidean')
  571. self.src_estm_dists = distance.cdist(src_loc.T, est_loc.T, 'euclidean')
  572. self.dist_max = max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R
  573. def forward_model(self, x, R, h, sigma, src_type):
  574. """FWD model functions
  575. Evaluates potential at point (x,0) by a basis source located at (0,0)
  576. Eq 22 kCSD by Jan,2012
  577. Parameters
  578. ----------
  579. x : float
  580. R : float
  581. h : float
  582. sigma : float
  583. src_type : basis_2D.key
  584. Returns
  585. -------
  586. pot : float
  587. value of potential at specified distance from the source
  588. """
  589. pot, err = integrate.dblquad(self.int_pot_2D,
  590. -R, R,
  591. lambda x: -R,
  592. lambda x: R,
  593. args=(x, R, h, src_type))
  594. pot *= 1./(2.0*np.pi*sigma) #Potential basis functions bi_x_y
  595. return pot
  596. def int_pot_2D(self, xp, yp, x, R, h, basis_func):
  597. """FWD model function.
  598. Returns contribution of a point xp,yp, belonging to a basis source
  599. support centered at (0,0) to the potential measured at (x,0),
  600. integrated over xp,yp gives the potential generated by a
  601. basis source element centered at (0,0) at point (x,0)
  602. Parameters
  603. ----------
  604. xp, yp : floats or np.arrays
  605. point or set of points where function should be calculated
  606. x : float
  607. position at which potential is being measured
  608. R : float
  609. The size of the basis function
  610. h : float
  611. thickness of slice
  612. basis_func : method
  613. Fuction of the basis source
  614. Returns
  615. -------
  616. pot : float
  617. """
  618. y = ((x-xp)**2 + yp**2)**(0.5)
  619. if y < 0.00001:
  620. y = 0.00001
  621. dist = np.sqrt(xp**2 + yp**2)
  622. pot = np.arcsinh(h/y)*basis_func(dist, R)
  623. return pot
  624. class MoIKCSD(KCSD2D):
  625. """MoIKCSD - CSD while including the forward modeling effects of saline.
  626. This estimates the Current Source Density, for a given configuration of
  627. electrod positions and recorded potentials, in the case of 2D recording
  628. electrodes from an MEA electrode plane using the Method of Images.
  629. The method implented here is based on kCSD method by Jan Potworowski
  630. et.al. 2012, which was extended in Ness, Chintaluri 2015 for MEA.
  631. """
  632. def __init__(self, ele_pos, pots, **kwargs):
  633. """Initialize MoIKCSD Class.
  634. Parameters
  635. ----------
  636. ele_pos : numpy array
  637. positions of electrodes
  638. pots : numpy array
  639. potentials measured by electrodes
  640. **kwargs
  641. configuration parameters, that may contain the following keys:
  642. src_type : str
  643. basis function type ('gauss', 'step', 'gauss_lim')
  644. Defaults to 'gauss'
  645. sigma : float
  646. space conductance of the tissue in S/m
  647. Defaults to 1 S/m
  648. sigma_S : float
  649. conductance of the saline (medium) in S/m
  650. Default is 5 S/m (5 times more conductive)
  651. n_src_init : int
  652. requested number of sources
  653. Defaults to 1000
  654. R_init : float
  655. demanded thickness of the basis element
  656. Defaults to 0.23
  657. h : float
  658. thickness of analyzed tissue slice
  659. Defaults to 1.
  660. xmin, xmax, ymin, ymax : floats
  661. boundaries for CSD estimation space
  662. Defaults to min(ele_pos(x)), and max(ele_pos(x))
  663. Defaults to min(ele_pos(y)), and max(ele_pos(y))
  664. ext_x, ext_y : float
  665. length of space extension: x_min-ext_x ... x_max+ext_x
  666. length of space extension: y_min-ext_y ... y_max+ext_y
  667. Defaults to 0.
  668. gdx, gdy : float
  669. space increments in the estimation space
  670. Defaults to 0.01(xmax-xmin)
  671. Defaults to 0.01(ymax-ymin)
  672. lambd : float
  673. regularization parameter for ridge regression
  674. Defaults to 0.
  675. MoI_iters : int
  676. Number of interations in method of images.
  677. Default is 20
  678. """
  679. self.MoI_iters = kwargs.pop('MoI_iters', 20)
  680. self.sigma_S = kwargs.pop('sigma_S', 5.0)
  681. self.sigma = kwargs.pop('sigma', 1.0)
  682. W_TS = (self.sigma - self.sigma_S) / (self.sigma + self.sigma_S)
  683. self.iters = np.arange(self.MoI_iters) + 1 #Eq 6, Ness (2015)
  684. self.iter_factor = W_TS**self.iters
  685. super(MoIKCSD, self).__init__(ele_pos, pots, **kwargs)
  686. def forward_model(self, x, R, h, sigma, src_type):
  687. """FWD model functions
  688. Evaluates potential at point (x,0) by a basis source located at (0,0)
  689. Eq 22 kCSD by Jan,2012
  690. Parameters
  691. ----------
  692. x : float
  693. R : float
  694. h : float
  695. sigma : float
  696. src_type : basis_2D.key
  697. Returns
  698. -------
  699. pot : float
  700. value of potential at specified distance from the source
  701. """
  702. pot, err = integrate.dblquad(self.int_pot_2D_moi, -R, R,
  703. lambda x: -R,
  704. lambda x: R,
  705. args=(x, R, h, src_type))
  706. pot *= 1./(2.0*np.pi*sigma)
  707. return pot
  708. def int_pot_2D_moi(self, xp, yp, x, R, h, basis_func):
  709. """FWD model function. Incorporates the Method of Images.
  710. Returns contribution of a point xp,yp, belonging to a basis source
  711. support centered at (0,0) to the potential measured at (x,0),
  712. integrated over xp,yp gives the potential generated by a
  713. basis source element centered at (0,0) at point (x,0)
  714. #Eq 20, Ness(2015)
  715. Parameters
  716. ----------
  717. xp, yp : floats or np.arrays
  718. point or set of points where function should be calculated
  719. x : float
  720. position at which potential is being measured
  721. R : float
  722. The size of the basis function
  723. h : float
  724. thickness of slice
  725. basis_func : method
  726. Fuction of the basis source
  727. Returns
  728. -------
  729. pot : float
  730. """
  731. L = ((x-xp)**2 + yp**2)**(0.5)
  732. if L < 0.00001:
  733. L = 0.00001
  734. correction = np.arcsinh((h-(2*h*self.iters))/L) + np.arcsinh((h+(2*h*self.iters))/L)
  735. pot = np.arcsinh(h/L) + np.sum(self.iter_factor*correction)
  736. dist = np.sqrt(xp**2 + yp**2)
  737. pot *= basis_func(dist, R) #Eq 20, Ness et.al.
  738. return pot
  class KCSD3D(KCSD):
      """KCSD3D - The 3D variant for the Kernel Current Source Density method.

      This estimates the Current Source Density, for a given configuration of
      electrode positions and recorded potentials, in the case of 3D recording
      electrodes. The method implemented here is based on the original paper
      by Jan Potworowski et.al. 2012.
      """

      def __init__(self, ele_pos, pots, **kwargs):
          """Initialize KCSD3D Class.

          All work is delegated to the KCSD base-class constructor.

          Parameters
          ----------
          ele_pos : numpy array
              positions of electrodes
          pots : numpy array
              potentials measured by electrodes
          **kwargs
              configuration parameters, that may contain the following keys:
              src_type : str
                  basis function type ('gauss', 'step', 'gauss_lim')
                  Defaults to 'gauss'
              sigma : float
                  space conductance of the tissue in S/m
                  Defaults to 1 S/m
              n_src_init : int
                  requested number of sources
                  Defaults to 1000
              R_init : float
                  demanded thickness of the basis element
                  Defaults to 0.23
              h : float
                  thickness of analyzed tissue slice
                  Defaults to 1.
              xmin, xmax, ymin, ymax, zmin, zmax : floats
                  boundaries for CSD estimation space
                  Defaults to min(ele_pos(x)), and max(ele_pos(x))
                  Defaults to min(ele_pos(y)), and max(ele_pos(y))
                  Defaults to min(ele_pos(z)), and max(ele_pos(z))
              ext_x, ext_y, ext_z : float
                  length of space extension: xmin-ext_x ... xmax+ext_x
                  length of space extension: ymin-ext_y ... ymax+ext_y
                  length of space extension: zmin-ext_z ... zmax+ext_z
                  Defaults to 0.
              gdx, gdy, gdz : float
                  space increments in the estimation space
                  Defaults to 0.01(xmax-xmin)
                  Defaults to 0.01(ymax-ymin)
                  Defaults to 0.01(zmax-zmin)
              lambd : float
                  regularization parameter for ridge regression
                  Defaults to 0.

          Raises
          ------
          LinAlgError
              Could not invert the matrix, try changing the ele_pos slightly
          KeyError
              Basis function (src_type) not implemented. See basis_functions.py
              for available ones
          """
          super(KCSD3D, self).__init__(ele_pos, pots, **kwargs)

      def estimate_at(self):
          """Define the grid of points at which the CSD is estimated.

          Defines:
          self.estm_x, self.estm_y, self.estm_z : Pts. at which CSD is requested
          self.n_estm = self.estm_x.size
          self.ngx, self.ngy, self.ngz = self.estm_x.shape

          Parameters
          ----------
          None
          """
          nx = (self.xmax - self.xmin)/self.gdx
          ny = (self.ymax - self.ymin)/self.gdy
          nz = (self.zmax - self.zmin)/self.gdz
          # An imaginary step tells np.mgrid to interpret its magnitude as the
          # number of grid points (end point included). The builtin complex()
          # is used because the np.complex alias was removed in NumPy 1.24.
          self.estm_x, self.estm_y, self.estm_z = np.mgrid[
              self.xmin:self.xmax:complex(0, nx),
              self.ymin:self.ymax:complex(0, ny),
              self.zmin:self.zmax:complex(0, nz)]
          self.n_estm = self.estm_x.size
          self.ngx, self.ngy, self.ngz = self.estm_x.shape

      def place_basis(self):
          """Place basis sources of the defined type.

          Looks up self.src_type in basis.basis_3D and distributes the basis
          sources over the estimation space (extended by ext_x/y/z).

          Defines
          self.basis : basis function taken from basis.basis_3D
          self.src_x, self.src_y, self.src_z : Locations at which basis sources are placed.
          self.R : basis size returned by utils.distribute_srcs_3D (from R_init)
          self.n_src = self.src_x.size
          self.nsx, self.nsy, self.nsz = self.src_x.shape

          Parameters
          ----------
          None

          Raises
          ------
          KeyError
              If self.src_type is not a key of basis.basis_3D.
          """
          source_type = self.src_type
          try:
              self.basis = basis.basis_3D[source_type]
          except KeyError:
              raise KeyError('Invalid source_type for basis! available are:',
                             basis.basis_3D.keys())
          (self.src_x, self.src_y, self.src_z, self.R) = utils.distribute_srcs_3D(
              self.estm_x,
              self.estm_y,
              self.estm_z,
              self.n_src_init,
              self.ext_x,
              self.ext_y,
              self.ext_z,
              self.R_init)
          self.n_src = self.src_x.size
          self.nsx, self.nsy, self.nsz = self.src_x.shape

      def create_src_dist_tables(self):
          """Create distance tables between sources, electrodes and estm points.

          Defines
          self.src_ele_dists : Euclidean distances, one row per source and
              one column per electrode
          self.src_estm_dists : Euclidean distances, one row per source and
              one column per estimation point
          self.dist_max : largest distance in either table plus self.R

          Parameters
          ----------
          None
          """
          src_loc = np.array((self.src_x.ravel(),
                              self.src_y.ravel(),
                              self.src_z.ravel()))
          est_loc = np.array((self.estm_x.ravel(),
                              self.estm_y.ravel(),
                              self.estm_z.ravel()))
          self.src_ele_dists = distance.cdist(src_loc.T, self.ele_pos, 'euclidean')
          self.src_estm_dists = distance.cdist(src_loc.T, est_loc.T, 'euclidean')
          self.dist_max = max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R

      def forward_model(self, x, R, h, sigma, src_type):
          """FWD model function.

          Evaluates the potential at point (x,0,0) generated by a basis source
          located at (0,0,0). Closed-form expressions are used for the
          'gauss_3D', 'gauss_lim_3D' and 'step_3D' basis functions; any other
          basis is integrated numerically over the cube [-R, R]^3, with the
          skmonaco Monte Carlo (MISER) method if available, otherwise with
          scipy integrate.tplquad.

          Parameters
          ----------
          x : float
              x-coordinate of the measurement point (its distance from the
              source centre)
          R : float
              size of the basis function
          h : float
              thickness of slice; not used by any of the 3D expressions
          sigma : float
              space conductance of the tissue
          src_type : callable
              basis function (a value of basis_3D); its __name__ selects the
              closed-form solution

          Returns
          -------
          pot : float
              value of potential at specified distance from the source
          """
          if src_type.__name__ == "gauss_3D":
              if x == 0:
                  x = 0.0001  # avoid 0/0; erf(x/a)/x has a finite limit at 0
              pot = special.erf(x/(np.sqrt(2)*R/3.0)) / x
          elif src_type.__name__ == "gauss_lim_3D":
              if x == 0:
                  x = 0.0001
              d = R/3.
              if x < R:
                  # Charge enclosed within radius x, plus the contribution of
                  # the Gaussian shell between x and the truncation radius R.
                  e = np.exp(-(x/(np.sqrt(2)*d))**2)
                  erf = special.erf(x/(np.sqrt(2)*d))
                  pot = 4*np.pi*((d**2)*(e - np.exp(-4.5)) +
                                 (1/x)*((np.sqrt(np.pi/2)*(d**3)*erf) - x*(d**2)*e))
              else:
                  # 15.28828 ~= 4*pi*(sqrt(pi/2)*erf(3/sqrt(2)) - 3*exp(-4.5)),
                  # which makes this branch continuous with the one above at x == R.
                  pot = 15.28828*(d)**3 / x
              pot /= (np.sqrt(2*np.pi)*d)**3
          elif src_type.__name__ == "step_3D":
              # Q is the volume of the ball of radius R.
              Q = 4.*np.pi*(R**3)/3.
              if x < R:
                  pot = (Q * (3 - (x/R)**2)) / (2.*R)
              else:
                  pot = Q / x
              pot *= 3/(4*np.pi*R**3)
          else:
              if skmonaco_available:
                  pot, err = mcmiser(self.int_pot_3D_mc,
                                     npoints=1e5,
                                     xl=[-R, -R, -R],
                                     xu=[R, R, R],
                                     seed=42,
                                     nprocs=num_cores,
                                     args=(x, R, h, src_type))
              else:
                  # Constant bounds: the bound callables ignore their arguments
                  # (named _x/_y so they do not shadow the measurement point x).
                  pot, err = integrate.tplquad(self.int_pot_3D,
                                               -R,
                                               R,
                                               lambda _x: -R,
                                               lambda _x: R,
                                               lambda _x, _y: -R,
                                               lambda _x, _y: R,
                                               args=(x, R, h, src_type))
          pot *= 1./(4.0*np.pi*sigma)
          return pot

      def int_pot_3D(self, xp, yp, zp, x, R, h, basis_func):
          """FWD model function.

          Returns contribution of a point (xp,yp,zp), belonging to a basis
          source support centered at (0,0,0), to the potential measured at
          (x,0,0); integrated over xp,yp,zp it gives the potential generated
          by a basis source element centered at (0,0,0) at point (x,0,0).

          Parameters
          ----------
          xp, yp, zp : floats or np.arrays
              point or set of points where function should be calculated
          x : float
              position at which potential is being measured
          R : float
              The size of the basis function
          h : float
              thickness of slice (not used here)
          basis_func : method
              Function of the basis source

          Returns
          -------
          pot : float or np.array
          """
          y = ((x - xp)**2 + yp**2 + zp**2)**0.5
          # Clamp to avoid division by zero at the measurement point;
          # np.maximum also works element-wise for array input.
          y = np.maximum(y, 0.00001)
          dist = np.sqrt(xp**2 + yp**2 + zp**2)
          pot = 1.0/y
          pot *= basis_func(dist, R)
          return pot

      def int_pot_3D_mc(self, xyz, x, R, h, basis_func):
          """Monte Carlo adapter for int_pot_3D.

          The same as int_pot_3D, but the integration point arrives packed as
          one sequence, xyz = (xp, yp, zp), as used by the mcmiser call in
          forward_model.

          Parameters
          ----------
          xyz : sequence of three floats
              point (xp, yp, zp) where function should be calculated
          x : float
              position at which potential is being measured
          R : float
              The size of the basis function
          h : float
              thickness of slice (not used)
          basis_func : method
              Function of the basis source

          Returns
          -------
          pot : float
          """
          xp, yp, zp = xyz
          return self.int_pot_3D(xp, yp, zp, x, R, h, basis_func)
  if __name__ == '__main__':
      # Smoke tests: fit each estimator on a tiny hand-made electrode layout,
      # run cross-validation and print the resulting estimate.
      print('Checking 1D')
      ele_pos = np.array(([-0.1], [0], [0.5], [1.], [1.4], [2.], [2.3]))
      pots = np.array([[-1], [-1], [-1], [0], [0], [1], [-1.5]])
      k = KCSD1D(ele_pos, pots,
                 gdx=0.01, n_src_init=300,
                 ext_x=0.0, src_type='gauss')
      k.cross_validate()
      print(k.values())

      print('Checking 2D')
      ele_pos = np.array([[-0.2, -0.2], [0, 0], [0, 1], [1, 0], [1, 1], [0.5, 0.5],
                          [1.2, 1.2]])
      pots = np.array([[-1], [-1], [-1], [0], [0], [1], [-1.5]])
      k = KCSD2D(ele_pos, pots,
                 gdx=0.05, gdy=0.05,
                 xmin=-2.0, xmax=2.0,
                 ymin=-2.0, ymax=2.0,
                 src_type='gauss')
      k.cross_validate()
      print(k.values())

      print('Checking MoIKCSD')
      # Reuses the 2D electrode layout and potentials defined above.
      k = MoIKCSD(ele_pos, pots,
                  gdx=0.05, gdy=0.05,
                  xmin=-2.0, xmax=2.0,
                  ymin=-2.0, ymax=2.0)
      k.cross_validate()
      print(k.values())

      print('Checking KCSD3D')
      ele_pos = np.array([(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0),
                          (0, 1, 1), (1, 1, 0), (1, 0, 1), (1, 1, 1),
                          (0.5, 0.5, 0.5)])
      pots = np.array([[-0.5], [0], [-0.5], [0], [0], [0.2], [0], [0], [1]])
      k = KCSD3D(ele_pos, pots,
                 gdx=0.02, gdy=0.02, gdz=0.02,
                 n_src_init=1000, src_type='gauss_lim')
      k.cross_validate()
      print(k.values())