kmodelY.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from scipy.optimize import least_squares
  2. from statsmodels.stats.stattools import durbin_watson
  3. import numpy as np
  4. def fitKmodel(subdata, nolog=None, pfit=None, p0=None):
  5. """
  6. Parameters:
  7. - subdata : subject data
  8. - nolog : If zero, logarithm is not used (default is 0)
  9. - pfit : A list of logical values to indicate which parameters to fit (default is [True, True, True])
  10. - p0 : Initial parameters (must always have length 3)
  11. Returns:
  12. - px : Parameters of the model
  13. By S.Glasauer 2019 (matlab), translated to Python by Strongway
  14. # add AIC and DW
  15. """
  16. # Handle default arguments
  17. if p0 is None:
  18. p0 = [1., 1, 0]
  19. if pfit is None:
  20. pfit = [True, True, True]
  21. if nolog is None:
  22. nolog = 0
  23. # Convert pfit to logical and filter p0
  24. pfit = np.array(pfit, dtype=bool)
  25. p0 = np.array(p0)[pfit]
  26. # Lower bounds (lb) for the optimization
  27. lb = np.array([0, 0, -np.inf])[pfit]
  28. # extract Duration and Reproduction from subdata as 2d array
  29. x = subdata['Duration'].values
  30. y = subdata['Reproduction'].values
  31. # replace extreme y with nan with y > 3 * x or y < x/3
  32. y[(y > 3 * x) | (y < x/3)] = np.nan
  33. # combine x,y as 2d array
  34. stimrep = np.vstack([x,y]).T
  35. # Perform the optimization using least_squares (equivalent to lsqnonlin in MATLAB)
  36. result = least_squares(kmodelY, p0, args = (stimrep, 1),
  37. bounds=(lb, np.inf), method='trf')
  38. # calculate kalmann filter parameters
  39. q11 = result.x[0]
  40. q22 = result.x[1]
  41. r = 1
  42. # calculate residual sum of squares
  43. rss = np.sum(result.fun**2)
  44. dw = durbin_watson(result.fun)
  45. # number of parameters
  46. k = len(result.x)
  47. # number of observations
  48. n = len(stimrep)
  49. # calculate the log-likelihood
  50. ll = -n/2*(np.log(2*np.pi) + np.log(rss/n) + 1)
  51. # calculate the Akaike information criterion (AIC)
  52. aic = 2*k - 2*ll
  53. # steady state solution
  54. p22 = (q22+np.sqrt(q22*q22+4*(q11+r)*q22))/2
  55. K = np.array([p22 + q11, p22])/(p22+q11+r)
  56. # return the optimized parameters, steady state solution, and AIC
  57. return np.append(np.append(result.x, K), [aic, dw]) # Optimized parameters
  58. def kmodelY(par, stimrep, nolog=1, pfit=[1, 1, 1]):
  59. """
  60. Function to perform Kalman filter-based estimation.
  61. Parameters:
  62. - par: Model parameters (if pfit = [1,1,1], then par = [q1/r, q2/r, cost-related parameter (0 for median)])
  63. - stimrep: Stimulus representation
  64. - nolog: Flag to decide if logarithm transformation is needed
  65. - pfit: Parameter fitting list (note: len(par) = sum(pfit))
  66. Returns:
  67. - sres: Stimulus residuals
  68. - xest: Estimated state
  69. - pest: Estimate error covariance
  70. - resp: Response
  71. - perr: Prediction error
  72. S.Glasauer 2019/2023, translated to Python by Strongway
  73. """
  74. # Convert pfit to a boolean array
  75. pfit = np.array(pfit, dtype=bool)
  76. # Adjust pfit based on the size of par
  77. if len(par) < 3:
  78. pfit[len(par):] = False
  79. # Adjust stimrep's shape for further processing
  80. if stimrep.shape[1] == 1:
  81. stimrep = np.tile(stimrep, (1, 2))
  82. # the first column is the stimulus, the second column is the response,
  83. # and add the third column to indicate the start of a new sequence
  84. #if stimrep.shape[1] == 2:
  85. # stimrep = np.hstack((stimrep, np.zeros((stimrep.shape[0], 1))))
  86. # stimrep[0, 2] = 1
  87. # Initialize pars and overwrite with provided parameters based on pfit
  88. pars = np.array([0.0, 0.0, 0.0])
  89. pars[pfit] = par
  90. par = pars
  91. # Constants for the model
  92. a = 10.0
  93. off = 1.
  94. r = 1.
  95. q1 = par[0] * r
  96. q2 = par[1] * r
  97. # Define matrices Q, P, H, and F for the Kalman filter of two-state model
  98. # details see Glasauer & Shi, 2022, Sci. Rep., https://doi.org/10.1038/s41598-022-14939-8
  99. Q = np.array([[q1, 0], [0, q2]])
  100. P = np.array([[r, 0], [0, r]])
  101. H = np.array([[1., 0]])
  102. F = np.array([[0, 1.], [0, 1.]])
  103. # Apply logarithm transformation if nolog is false
  104. if nolog:
  105. z = stimrep[:, 0]
  106. else: # log transformation
  107. z = np.log(a * stimrep[:, 0] + off)
  108. # Initialize state vector x
  109. x = np.array([[z[0]], [z[0]]])
  110. # Initialize matrices for storing results
  111. xest = np.zeros((len(z), 2))
  112. pest = np.zeros((len(z), 2))
  113. perr = np.zeros(len(z))
  114. # Kalman filter estimation loop
  115. for i in range(len(z)):
  116. x = F@x
  117. P = F@P@F.T + Q
  118. K = P@H.T/(H@P@H.T + r)
  119. perr[i] = z[i] - H@x
  120. x = x + K*perr[i]
  121. P = (np.eye(2) - K@H)@P
  122. pest[i, :] = np.diag(P)
  123. xest[i, :] = x.reshape(-1)
  124. # Adjust for third parameter, if present
  125. if len(par) == 3:
  126. sh = par[2]
  127. else:
  128. sh = 0
  129. # Compute response, adjusting for logarithm if needed
  130. if nolog:
  131. resp = xest[:, 0] + sh
  132. else: # log transformation
  133. resp = (np.exp(xest[:, 0] + sh) - off)/a
  134. # Calculate stimulus residuals
  135. sres = stimrep[:, 1] - resp
  136. # Remove NaNs from sres
  137. sres = sres[np.isfinite(sres)]
  138. return sres