fit_pf_nls.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. """
  4. @author: yannansu
  5. @created at: 22.03.21 18:10
  6. Fit Psychometric functions using non-linear least squares.
  7. Example usage:
  8. test_data = LoadData('s01', data_path='data', sel_par=['LL_2x2']).read_data()
  9. test_FitPf = FitPf(test_data) # or test_FitPf = FitPf_Correctness(test_data)
  10. test_fit = test_FitPf.fit()
  11. test_FitPf.plot_pf_curve()
  12. test_FitPf.plot_pf_param() # not available for FitPf_Correctness
  13. """
  14. import pandas as pd
  15. import numpy as np
  16. import pylab
  17. import matplotlib as mpl
  18. import matplotlib.pyplot as plt
  19. from psychopy import data
  20. from .fit_curves import FitCumNormal
  21. from .color4plot import color4plot
  22. from matplotlib import rc
  23. # Set the global font to be DejaVu Sans, size 10 (or any other sans-serif font of your choice!)
  24. rc("font", **{'family': 'sans-serif', 'sans-serif': ['DejaVu Sans'], 'size': 15})
  25. # Set the font used for MathJax - more on this later
  26. rc('mathtext', **{"default": 'regular'})
  27. # Define params for plotting
  28. plot_config = {"figsize": (16, 10),
  29. 'title_fontsize': 13,
  30. 'label_fontsize': 12,
  31. 'tick_size': 8
  32. }
  33. # mpl.rcParams['agg.path.chunksize'] = 1000000000
  34. class FitPf:
  35. """
  36. Fit 2AFC data to a psychometric curve (response 'Yes' - stimulus intensity)
  37. Example usage:
  38. df = LoadData('test', sel_par=['LL_set1']).read_data() # load data
  39. fit_dat = FitPf(df).fit() # do the fit
  40. FitPf(df).plot_pf_curve() # plot PF curves
  41. FitPf(df).plot_pf_param() # plot PF estimates
  42. """
  43. def __init__(self, df, guess=None, lapse=0.0, bins='unique'):
  44. """
  45. :param df: psychometric data (Dataframe)
  46. :param guess: guess parameter for fitting, default None as [0., 5.] (corresponding to [PSE, JND])
  47. :param lapse: lapse rate, default as 0.
  48. :param bins: number of bins for binning data before fitting, default as 'unique'
  49. """
  50. self.df = df
  51. if guess is None:
  52. self.guess = [0., 5.]
  53. else:
  54. self.guess = guess
  55. self.lapse = lapse
  56. self.bins = bins
  57. def fit(self):
  58. df_dict = dict(list(self.df.groupby('standard_stim')))
  59. # A nested dictionary
  60. fit_dat = {}
  61. for key, d in df_dict.items():
  62. fit_dat[key] = d.to_dict(orient='list')
  63. fit_dat[key]['Hue Angle'] = key
  64. fit_dat[key]['Trial N'] = len(df_dict[key])
  65. # Bin data
  66. fit_dat[key]['Binned Intensities'], \
  67. fit_dat[key]['Binned Responses'], \
  68. fit_dat[key]['Binned N'] = data.functionFromStaircase(
  69. fit_dat[key]['actual_intensity'],
  70. fit_dat[key]['resp_as_larger'],
  71. bins=self.bins)
  72. # Sems is defined as 1/weight in Psychopy
  73. fit_dat[key]['sems'] = [sum(fit_dat[key]['Binned N']) / n
  74. for n in fit_dat[key]['Binned N']]
  75. fit_dat[key]['fit'] = FitCumNormal(fit_dat[key]['Binned Intensities'],
  76. fit_dat[key]['Binned Responses'],
  77. sems=fit_dat[key]['sems'], guess=self.guess,
  78. expectedMin=0.0, lapse=self.lapse) # customized cumulative Gaussian
  79. fit_dat[key]['PSE'], fit_dat[key]['JND'] = fit_dat[key]['fit'].params
  80. fit_dat[key]['PSE_err'], fit_dat[key]['JND_err'] = np.sqrt(np.diagonal(fit_dat[key]['fit'].covar))
  81. fit_dat[key]['ssq'] = fit_dat[key]['fit'].ssq
  82. return fit_dat
  83. def plot_pf_curve(self, save_pdf=None):
  84. """
  85. Plot PFs from nonlinear least squares fitting results.
  86. """
  87. fit_dat = self.fit()
  88. num = len(fit_dat)
  89. if num == 1:
  90. fig, axes = plt.subplots(num, 1, figsize=plot_config['figsize'])
  91. else:
  92. fig, axes = plt.subplots(2, int(num / 2), figsize=plot_config['figsize'])
  93. xlim = [-20, 20]
  94. ylim = [0, 1.0]
  95. for idx, key in enumerate(fit_dat):
  96. this_dat = fit_dat[key]
  97. this_fit = this_dat['fit']
  98. ntrial = this_dat['Trial N']
  99. if num == 1:
  100. ax = axes
  101. else:
  102. ax = axes.flatten()[idx]
  103. hue_angle = this_dat['Hue Angle']
  104. color_code = color4plot(hue_angle)[0]
  105. # hue_angles = np.array([float(k) for k in fit_dat.keys()])
  106. # color = colorcodes[int((key - self.first_angle) / self.first_angle / 2)]
  107. for inten, resp, se in zip(this_dat['Binned Intensities'],
  108. this_dat['Binned Responses'],
  109. this_dat['sems']):
  110. ax.plot(inten, resp, '.', color=color_code, alpha=0.5, markersize=30 / np.log(se))
  111. smoothResp = pylab.arange(0.0, 1.0, .02)
  112. smoothInt = this_fit.inverse(smoothResp)
  113. ax.plot(smoothInt, smoothResp, '-', color=color_code) # plot fitted curve
  114. for val in [0.25, 0.5, 0.75]:
  115. ax.hlines(y=val, xmin=xlim[0], xmax=this_fit.inverse(val), linestyles='dashed', colors='grey')
  116. ax.vlines(x=this_fit.inverse(val), ymin=ylim[0], ymax=val, linestyles='dashed', colors='grey')
  117. ssq = np.round(this_dat['ssq'], decimals=3) # sum-squared error
  118. ax.text(3.5, 0.55, 'ssq = ' + str(ssq), fontsize=plot_config['tick_size'])
  119. ax.set_title('hue_angle: ' + str(hue_angle) + ', ' + '%dtrials' % ntrial,
  120. fontsize=plot_config['title_fontsize'])
  121. ax.set_xlim(xlim)
  122. # ax.set_ylim(ylim)
  123. ax.tick_params(axis='both', which='major', labelsize=plot_config['tick_size'])
  124. if num == 1:
  125. x_ax = ax
  126. y_ax = ax
  127. elif num == 2:
  128. x_ax = axes[-1]
  129. y_ax = axes[0]
  130. else:
  131. x_ax = axes[-1, :]
  132. y_ax = axes[:, 0]
  133. plt.setp(x_ax, xlabel='Hue Angle')
  134. plt.setp(y_ax, ylabel='Response "Test hue angle is larger" ')
  135. plt.setp(ax.get_xticklabels(), fontsize=plot_config['tick_size'])
  136. plt.setp(ax.get_yticklabels(), fontsize=plot_config['tick_size'])
  137. # fig.suptitle(self.sub[0:2] + '_' + str(ntrial) + 'trials', fontsize=plot_config['title_fontsize'])
  138. fig.suptitle(str(ntrial) + ' trials', fontsize=plot_config['title_fontsize'])
  139. if save_pdf is not None:
  140. plt.savefig('data_analysis/figures/' + save_pdf + '.pdf')
  141. plt.show()
  142. def plot_pf_param(self, save_pdf=None):
  143. """
  144. Plot estimated PF parameters from nonlinear least squares fitting results.
  145. """
  146. fit_dat = pd.DataFrame(self.fit()).T
  147. num = len(fit_dat)
  148. hue_angles = fit_dat.index
  149. color_codes = color4plot(hue_angles)
  150. plt.figure(figsize=(6, 5))
  151. plt.title('Cumulative Gaussian Parameter Etimates, ' + '%dtrials' % fit_dat['Trial N'].unique()[0],
  152. fontsize=plot_config['title_fontsize'])
  153. plt.xlabel('Hue Angle', fontsize=plot_config['label_fontsize'])
  154. plt.ylabel('Parameter Estimates', fontsize=plot_config['label_fontsize'])
  155. ax = plt.subplot(111)
  156. ax.errorbar(hue_angles, fit_dat.PSE, yerr=fit_dat.PSE_err, label='PSE', ls='-', color=[0.3, 0.3, 0.3])
  157. ax.errorbar(hue_angles, fit_dat.JND, yerr=fit_dat.JND_err, label='JND', ls='-', color=[0.6, 0.6, 0.6])
  158. ax.scatter(hue_angles, fit_dat.PSE, color=color_codes, s=60)
  159. ax.scatter(hue_angles, fit_dat.JND, color=color_codes, s=60)
  160. ax.hlines(0, 0, 360, linestyles='dashed', color='silver')
  161. # xlabels = [f"{l}\n{a}" for l, a in zip(psypar['hue_id'], psypar['angle'])]
  162. # xlabels = hue_angles
  163. ax.set_xticks(hue_angles)
  164. ax.set_xticklabels(hue_angles, rotation=45, fontsize=plot_config['tick_size'])
  165. ax.set_xlim([0, 360])
  166. plt.legend(fontsize=plot_config['tick_size'])
  167. if save_pdf is not None:
  168. plt.savefig('data_analysis/figures/' + save_pdf + '.pdf')
  169. plt.show()
  170. class FitPf_Correctness:
  171. """
  172. Fit 2AFC data to a psychometric curve (response correctness - (absolute) stimulus intensity)
  173. Note it is similar to the class FitPF, but with different Y values.
  174. """
  175. def __init__(self, df, guess=None, lapse=0.0, bins=None, func='CumNormal'):
  176. """
  177. :param df: psychometric data (Dataframe)
  178. :param guess: guess parameter for fitting, default None as [0., 5.] (corresponding to [PSE, JND])
  179. :param lapse: lapse rate, default as 0.
  180. :param bins: number of bins for binning data before fitting, default as 'unique'
  181. """
  182. self.df = df
  183. self.df['labeled_stim'] = (-1) ** (self.df['actual_intensity'] < 0) * self.df[
  184. 'standard_stim'] # to make the following fitting easier, create a new column of standard stimulus with sign labels
  185. if guess is None:
  186. self.guess = [5., 1.]
  187. else:
  188. self.guess = guess
  189. self.lapse = lapse
  190. if bins is None:
  191. self.bins = 'unique'
  192. else:
  193. self.bins = bins
  194. self.func = func
  195. def fit(self):
  196. df_dict = dict(list(self.df.groupby(['standard_stim', 'labeled_stim'])))
  197. # A nested dictionary
  198. fit_dat = {}
  199. for key, d in df_dict.items():
  200. fit_dat[key] = d.to_dict(orient='list')
  201. fit_dat[key]['Hue Angle'] = key[0]
  202. fit_dat[key]['labeled_stim'] = key[1]
  203. fit_dat[key]['Trial N'] = len(df_dict[key])
  204. fit_dat[key]['actual_intensity'] = [abs(x) for x in fit_dat[key]['actual_intensity']]
  205. # Bin data
  206. fit_dat[key]['Binned Intensities'], \
  207. fit_dat[key]['Binned Responses'], \
  208. fit_dat[key]['Binned N'] = data.functionFromStaircase(
  209. (fit_dat[key]['actual_intensity']),
  210. fit_dat[key]['judge'],
  211. bins=self.bins)
  212. # Sems is defined as 1/weight in Psychopy
  213. fit_dat[key]['sems'] = [sum(fit_dat[key]['Binned N']) / n
  214. for n in fit_dat[key]['Binned N']]
  215. if self.func == 'CumNormal':
  216. # customized cumulative Gaussian
  217. fit_dat[key]['fit'] = FitCumNormal(fit_dat[key]['Binned Intensities'],
  218. fit_dat[key]['Binned Responses'],
  219. sems=fit_dat[key]['sems'], guess=self.guess,
  220. expectedMin=0.5, lapse=self.lapse)
  221. elif self.func == 'Weibull':
  222. fit_dat[key]['fit'] = data.FitWeibull(fit_dat[key]['Binned Intensities'],
  223. fit_dat[key]['Binned Responses'],
  224. sems=fit_dat[key]['sems'], guess=self.guess,
  225. expectedMin=0.5)
  226. elif self.func == 'Logistic':
  227. fit_dat[key]['fit'] = data.FitLogistic(fit_dat[key]['Binned Intensities'],
  228. fit_dat[key]['Binned Responses'],
  229. sems=fit_dat[key]['sems'], guess=self.guess,
  230. expectedMin=0.5)
  231. else:
  232. raise ValueError("Given fitting function is not found.")
  233. fit_dat[key]['PSE'], fit_dat[key]['JND'] = fit_dat[key]['fit'].params
  234. fit_dat[key]['PSE_err'], fit_dat[key]['JND_err'] = np.sqrt(np.diagonal(fit_dat[key]['fit'].covar))
  235. fit_dat[key]['ssq'] = fit_dat[key]['fit'].ssq
  236. return fit_dat
  237. def plot_pf_curve(self):
  238. """
  239. Plot PFs from nonlinear least squares fitting results.
  240. """
  241. fit_dat = self.fit()
  242. num = len(fit_dat)
  243. if num == 1:
  244. fig, axes = plt.subplots(num, 1, figsize=plot_config['figsize'])
  245. else:
  246. fig, axes = plt.subplots(2, int(num / 4), figsize=plot_config['figsize'])
  247. xlim = [0, 12]
  248. ylim = [.5, 1.]
  249. for idx, key in enumerate(fit_dat):
  250. this_dat = fit_dat[key]
  251. this_fit = this_dat['fit']
  252. ntrial = this_dat['Trial N']
  253. if num == 1:
  254. ax = axes
  255. else:
  256. ax = axes.flatten()[int(np.floor(idx / 2))]
  257. if key[1] < 0:
  258. label = 'minus'
  259. marker = 'o'
  260. color = 'coral'
  261. else:
  262. label = 'plus'
  263. marker = 'P'
  264. color = 'skyblue'
  265. hue_angle = this_dat['Hue Angle']
  266. # color_code = color4plot(hue_angle)[0]
  267. # hue_angles = np.array([float(k) for k in fit_dat.keys()])
  268. # color = colorcodes[int((key - self.first_angle) / self.first_angle / 2)]
  269. for inten, resp, se in zip(this_dat['Binned Intensities'],
  270. this_dat['Binned Responses'],
  271. this_dat['sems']):
  272. ax.plot(inten, resp, '.', color=color, alpha=.5, markersize=5 / np.log10(se), marker=marker)
  273. smoothInt = pylab.arange(0, 12, .5)
  274. smoothResp = this_fit.eval(smoothInt)
  275. ax.plot(smoothInt, smoothResp, '--', color=color, label=label) # plot fitted curve
  276. # ax.hlines(y=0.75, xmin=0, xmax=this_fit.inverse(0.75), linestyles='dashed', colors='grey')
  277. # ax.vlines(x=this_fit.inverse(0.75), ymin=0.5, ymax=0.75, linestyles='dashed', colors='grey')
  278. # ssq = np.round(this_dat['ssq'], decimals=3) # sum-squared error
  279. # ax.text(3.5, 0.55, 'ssq = ' + str(ssq), fontsize=plot_config['tick_size'])
  280. ax.set_title('hue_angle: ' + str(hue_angle) + ', ' + '%dtrials' % ntrial,
  281. fontsize=plot_config['title_fontsize'])
  282. ax.set_xlim(xlim)
  283. ax.set_ylim(ylim)
  284. ax.tick_params(axis='both', which='major', labelsize=plot_config['tick_size'])
  285. if num == 1:
  286. x_ax = ax
  287. y_ax = ax
  288. elif num == 2:
  289. x_ax = axes[-1]
  290. y_ax = axes[0]
  291. else:
  292. x_ax = axes[-1, :]
  293. y_ax = axes[:, 0]
  294. plt.setp(x_ax, xlabel='Hue Angle')
  295. plt.setp(y_ax, ylabel='Correct Response')
  296. plt.setp(ax.get_xticklabels(), fontsize=plot_config['tick_size'])
  297. plt.setp(ax.get_yticklabels(), fontsize=plot_config['tick_size'])
  298. # fig.suptitle(self.sub[0:2] + '_' + str(ntrial) + 'trials', fontsize=plot_config['title_fontsize'])
  299. fig.suptitle(str(ntrial) + ' trials', fontsize=plot_config['title_fontsize'])
  300. plt.legend()
  301. plt.show()
  302. # s05_lh = LoadData('s05', data_path='data', sel_par=['LH_2x2']).read_data()
  303. # # test_fit = FitPf_Correctness(s05_lh, bins=6).fit()
  304. # # FitPf(s05_lh, guess=[1, 2]).plot_pf_curve()
  305. # FitPf_Correctness(s05_lh, guess=[5, 5], func='CumNormal').plot_pf_curve()