% ott_RW_V_base_asymm.m (975 bytes)

  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_base_asymm(startValues, spikeCounts, rewards, timeLocked)
  % OTT_RW_V_BASE_ASYMM  Negative Poisson log-likelihood of spike counts under
  % a Rescorla-Wagner value model with separate learning rates for positive
  % and negative reward prediction errors.
  %
  % Inputs:
  %   startValues - [alphaPPE, alphaNPE, slope, intercept]
  %                   alphaPPE: learning rate applied when RPE >= 0
  %                   alphaNPE: learning rate applied when RPE < 0
  %                   slope, intercept: map value V to a Poisson rate via
  %                   rate = exp(slope*V + intercept)
  %   spikeCounts - observed spike counts; must be size-compatible with
  %                 rate(timeLocked) for poisspdf (presumably one count per
  %                 kept trial -- confirm against callers)
  %   rewards     - reward on each trial; its length sets the number of trials
  %   timeLocked  - index vector or logical mask selecting the trials that
  %                 enter the likelihood (all trials still drive learning)
  %
  % Outputs:
  %   LH                   - negative log-likelihood, or 1e9 when any kept
  %                          trial's log-probability is not finite (-Inf/NaN)
  %   probSpike            - Poisson probability of each kept trial's count
  %   V                    - value on each kept trial, before that trial's update
  %   mean_predictedSpikes - predicted Poisson rate on each kept trial
  %   RPE                  - reward prediction error on each kept trial

      alphaPPE  = startValues(1);
      alphaNPE  = startValues(2);
      slope     = startValues(3);
      intercept = startValues(4);

      % Initial value. This is NaN when alphaPPE + alphaNPE == 0; the
      % finiteness check at the end turns that case into the penalty value.
      Vinit = alphaPPE / (alphaPPE + alphaNPE);

      trials = length(rewards);
      V = zeros(trials + 1, 1);   % V(t) = value before trial t's update
      RPE = zeros(trials, 1);
      V(1) = Vinit;

      % Learning rule: asymmetric Rescorla-Wagner update
      for t = 1:trials
          RPE(t) = rewards(t) - V(t);
          if RPE(t) >= 0
              V(t + 1) = V(t) + alphaPPE * RPE(t);
          else
              V(t + 1) = V(t) + alphaNPE * RPE(t);
          end
      end

      % Drop the value after the final trial; no observed trial uses it.
      V = V(1:trials);
      rateParam = exp(slope * V + intercept);

      % Mask to timeLocked trials (excludes trials where the animal didn't
      % lick fast enough). The masked rate is computed once and reused.
      mean_predictedSpikes = rateParam(timeLocked);
      probSpike = poisspdf(spikeCounts, mean_predictedSpikes);
      V   = V(timeLocked);
      RPE = RPE(timeLocked);

      % Penalize any parameter set whose likelihood is not finite. Testing
      % isfinite (not just isinf) also catches NaN probabilities, e.g. from a
      % NaN rate, which would otherwise make LH itself NaN.
      logProb = log(probSpike);
      if ~all(isfinite(logProb(:)))
          LH = 1e9;
      else
          LH = -sum(logProb(:));
      end
  end