fit_pf_bayes.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
#!/usr/bin/env python
  2. # coding=utf-8
  3. """
  4. @author: yannansu
  5. @created at: 14.06.21 11:51
  6. Fit Psychometric functions using Bayesfit.
  7. Example usage:
  8. test_data = LoadData('s01', data_path='data', sel_par=['LL_2x2']).read_data()
  9. test_FitPf = FitPf_bayes(test_data, ylabel='correct', func='norm', params=[[None, None], [True, True, False, False]])
  10. fit_dict = test_FitPf.fit()
  11. fit_df = test_FitPf._to_df()
  12. test_FitPf.plot_pf_curve()
  13. """
  14. import pandas as pd
  15. import numpy as np
  16. import matplotlib.pyplot as plt
  17. from data_analysis.color4plot import color4plot
  18. import bayesfit as bf
  19. from bayesfit.psyFunction import psyfunction as _psyfunction
  20. from matplotlib import rc
  21. from data_analysis.load_data import LoadData
  22. # Set the global font to be DejaVu Sans, size 10 (or any other sans-serif font of your choice!)
  23. rc("font", **{'family': 'sans-serif', 'sans-serif': ['DejaVu Sans'], 'size': 15})
  24. # Set the font used for MathJax - more on this later
  25. rc('mathtext', **{"default": 'regular'})
  26. # Define params for plotting
  27. plot_config = {"figsize": (16, 10),
  28. 'title_fontsize': 13,
  29. 'label_fontsize': 12,
  30. 'tick_size': 8
  31. }
  32. class FitPf_bayes:
  33. def __init__(self, df, ylabel="response", func='norm', params=None):
  34. self.df = df
  35. self.df['labeled_stim'] = (-1) ** (self.df['actual_intensity'] < 0) * self.df[
  36. 'standard_stim'] # to make the following fitting easier, create a new column of standard stimulus with sign labels
  37. self.ylabel = ylabel
  38. if self.ylabel not in ['response', 'correct']:
  39. raise ValueError("Given ylabel dose not apply!")
  40. self.func = func
  41. if self.func not in ["norm", "logistic", "weibull"]:
  42. raise ValueError("Given fitting function type dose not apply!")
  43. self.params = params
  44. if self.params is None:
  45. self.params = [[None, None],
  46. [True, True, False, False]]
  47. def fit(self):
  48. """
  49. Fit PF to data using Bayes fitting.
  50. :return:
  51. """
  52. if self.ylabel == 'response':
  53. grp_keys = ['standard_stim']
  54. elif self.ylabel == 'correct':
  55. grp_keys = ['standard_stim', 'labeled_stim']
  56. df_dict = dict(list(self.df.groupby(grp_keys)))
  57. # fit_dict = pd.DataFrame(columns=['scale', 'slope', 'threshold'])
  58. fit_dict = {}
  59. for key, df in df_dict.items():
  60. fit_dict[key] = {}
  61. fit_dict[key]['ylabel'] = self.ylabel
  62. fit_dict[key]['ntrial'] = len(df)
  63. if self.ylabel == 'response':
  64. fit_dict[key]['hue_angle'] = key
  65. data = df.groupby('actual_intensity')['resp_as_larger'].agg(['sum', 'count'])
  66. chance = 0.
  67. threshold = 0.5
  68. intensity = data.index.values
  69. else:
  70. fit_dict[key]['hue_angle'] = key[0]
  71. fit_dict[key]['labeled_stim'] = key[1]
  72. data = df.groupby('actual_intensity')['judge'].agg(['sum', 'count'])
  73. intensity = abs(data.index.values)
  74. chance = 0.5
  75. threshold = 0.75
  76. # Reform data as a m-row by 3-column Numpy array:
  77. # Stimulus intensity N-trials correct N-trials total
  78. data_matrix = np.transpose([intensity,
  79. data['sum'].values,
  80. data['count'].values])
  81. metrics, options = bf.fitmodel(data_matrix, nafc=2, sigmoid_type=self.func,
  82. threshold=threshold, density=100,
  83. param_ests=[self.params[0][0], self.params[0][1], chance, 0.001],
  84. param_free=self.params[1]) # parameters as [scale, slope, gamma, lambda]
  85. fit_dict[key]['data'] = data_matrix
  86. fit_dict[key]['metrics'] = metrics
  87. fit_dict[key]['options'] = options
  88. return fit_dict
  89. def _to_df(self):
  90. fit_dict = self.fit()
  91. df = pd.DataFrame(columns=['hue_angle', 'ntrial', 'JND', 'JND_SD', 'PSE', 'PSE_SD'])
  92. for key, fit in fit_dict.items():
  93. if len(key) > 1:
  94. key = key[1]
  95. df.loc[key, 'hue_angle'] = fit['hue_angle']
  96. df.loc[key, 'ntrial'] = fit['ntrial']
  97. df.loc[key, 'PSE'] = fit['metrics']['MAP'][0]
  98. df.loc[key, 'JND'] = fit['metrics']['MAP'][1]
  99. df.loc[key, 'PSE_SD'] = fit['metrics']['SD'][0]
  100. df.loc[key, 'JND_SD'] = fit['metrics']['SD'][1]
  101. df = df.reset_index()
  102. return df
  103. def plot_pf_curve(self):
  104. """
  105. Plot PFs from nonlinear least squares fitting results.
  106. """
  107. fit_dict = self.fit()
  108. num = len(fit_dict)
  109. if self.ylabel == 'response':
  110. ylabel = 'Prob. Response as larger hue angles'
  111. n_col = int(num / 2)
  112. xlim = [-18, 18]
  113. ylim = [-0.05, 1.05]
  114. else:
  115. ylabel = 'Prob. Correct judge'
  116. n_col = int(num / 4)
  117. xlim = [0, 18]
  118. ylim = [0.45, 1.05]
  119. if num == 1:
  120. fig, axes = plt.subplots(num, 1, figsize=plot_config['figsize'])
  121. else:
  122. fig, axes = plt.subplots(2, n_col, figsize=plot_config['figsize'])
  123. for idx, key in enumerate(fit_dict):
  124. this_fit = fit_dict[key]
  125. hue_angle = this_fit['hue_angle']
  126. ntrial = this_fit['ntrial']
  127. data = this_fit['data']
  128. options = this_fit['options']
  129. # Determine which values to use for vector of parameters
  130. param_guess = np.zeros(4)
  131. counter = 0
  132. for keys in options['param_free']:
  133. if keys is True:
  134. param_guess[counter] = this_fit['metrics']['MAP'][counter]
  135. elif keys is False:
  136. param_guess[counter] = options['param_ests'][counter]
  137. counter += 1
  138. if num == 1:
  139. ax = axes
  140. if this_fit['ylabel'] == 'response':
  141. ax = axes.flatten()[idx]
  142. label = None
  143. color = color4plot(hue_angle)[0]
  144. elif this_fit['ylabel'] == 'correct':
  145. ax = axes.flatten()[int(np.floor(idx / 2))]
  146. if key[1] < 0:
  147. label = 'minus'
  148. color = 'coral'
  149. else:
  150. label = 'plus'
  151. color = 'skyblue'
  152. # color_code = color4plot(hue_angle)[0]
  153. # hue_angles = np.array([float(k) for k in fit_dat.keys()])
  154. # color = colorcodes[int((key - self.first_angle) / self.first_angle / 2)]
  155. for i in range(data[:, 0].shape[0]):
  156. ax.scatter(data[i, 0],
  157. data[i, 1] / data[i, 2],
  158. color=color,
  159. s=data[i, 2] * 2,
  160. alpha=0.5,
  161. zorder=5,
  162. marker='o')
  163. # Generate smooth curve from fitted function
  164. # x_max = data[:, 0].max()
  165. # x_min = data[:, 0].min()
  166. x_min = xlim[0]
  167. x_max = xlim[1]
  168. x_est = np.linspace(x_min, x_max, 50)
  169. y_pred = _psyfunction(x_est,
  170. param_guess[0],
  171. param_guess[1],
  172. param_guess[2],
  173. param_guess[3],
  174. options['sigmoid_type'])
  175. ax.plot(x_est, y_pred, '-', color=color, label=label) # plot fitted curve
  176. # ax.axhline(y=thre_y, color='grey', linestyle='dashed', linewidth=1, zorder=1, alpha=0.5)
  177. # ax.axvline(x=this_fit['metrics']['threshold'], color='grey', linestyle='dashed', linewidth=1, zorder=1, alpha=0.5)
  178. ax.plot([x_min, this_fit['metrics']['threshold']],
  179. [options['threshold'], options['threshold']],
  180. color='grey',
  181. linestyle='dotted',
  182. linewidth=1,
  183. zorder=1,
  184. alpha=0.8)
  185. ax.plot([this_fit['metrics']['threshold'], this_fit['metrics']['threshold']],
  186. [0, options['threshold']],
  187. color='grey',
  188. linestyle='dotted',
  189. linewidth=1,
  190. zorder=1,
  191. alpha=0.8)
  192. # ssq = np.round(this_dat['ssq'], decimals=3) # sum-squared error
  193. # ax.text(3.5, 0.55, 'ssq = ' + str(ssq), fontsize=plot_config['tick_size'])
  194. ax.set_title('hue_angle: ' + str(hue_angle) + ', ' + '%dtrials' % ntrial,
  195. fontsize=plot_config['title_fontsize'])
  196. ax.set_xlim(xlim)
  197. ax.set_ylim(ylim)
  198. ax.tick_params(axis='both', which='major', labelsize=plot_config['tick_size'])
  199. if num == 1:
  200. x_ax = ax
  201. y_ax = ax
  202. elif num == 2:
  203. x_ax = axes[-1]
  204. y_ax = axes[0]
  205. else:
  206. x_ax = axes[-1, :]
  207. y_ax = axes[:, 0]
  208. plt.setp(x_ax, xlabel='Hue Angle')
  209. plt.setp(y_ax, ylabel=ylabel)
  210. plt.setp(ax.get_xticklabels(), fontsize=plot_config['tick_size'])
  211. plt.setp(ax.get_yticklabels(), fontsize=plot_config['tick_size'])
  212. # fig.suptitle(self.sub[0:2] + '_' + str(ntrial) + 'trials', fontsize=plot_config['title_fontsize'])
  213. fig.suptitle(str(ntrial) + ' trials', fontsize=plot_config['title_fontsize'])
  214. plt.legend()
  215. plt.show()
  216. # s05_ll = LoadData('s05', data_path='data', sel_par=['LL_2x2']).read_data()
  217. # FitPf = FitPf_bayes(s05_ll, ylabel='correct', func='norm', params=[[None, None], [True, True, False, False]])
  218. # fit_dict = FitPf.fit()
  219. # fit_df = FitPf._to_df()
  220. # FitPf.plot_pf_curve()