classdef FitResult < handle
    % FitResult  Collection of GLM fits to the spike train of a single neuron.
    %   Stores, for each fit, the rate estimate (lambda), coefficients (b),
    %   deviance, AIC/BIC/log-likelihood and fit statistics, together with
    %   goodness-of-fit summaries (KS statistics, rescaled-ISI autocorrelation,
    %   point-process residual) and plotting helpers. The class derives from
    %   handle, so setter methods (setKSStats, setInvGausStats, ...) modify
    %   the object in place.
    properties
        numResults            % number of fits currently stored
        lambda                % rate estimates, one dimension per fit
        numCoeffs             % number of coefficients of each fit
        fitType               % per-fit distribution name (e.g. 'poisson')
        b                     % cell array of coefficient vectors, one per fit
        dev                   % deviance of each fit
        AIC
        BIC
        logLL
        stats                 % cell array of per-fit statistics (each with an se field)
        configs               % ConfigColl describing each fit
        configNames
        neuronNumber
        neuralSpikeTrain      % nspikeTrain (or cell array of them) that was fit
        covLabels             % cell array: covariate labels used by each fit
        uniqueCovLabels       % union of all covariate labels
        indicesToUniqueLabels % cell array: per-fit indices into uniqueCovLabels
        numHist
        histObjects
        ensHistObjects
        flatMask              % (#unique labels x #fits) 0/1 usage mask
        Z
        U
        X
        Residual
        invGausStats          % struct with fields rhoSig, confBoundSig
        KSStats               % struct with fields xAxis, KSSorted, ks_stat, ...
        plotParams            % cached output of computePlotParams
        XvalData
        XvalTime
        validation
        minTime
        maxTime
    end
    properties (Constant,Hidden)
        colors={'b','g','r','c','m','y','k'};
    end
    methods
        function fitObj=FitResult(spikeObj,covLabels,numHist,histObjects,ensHistObj,lambda,b, dev, stats,AIC,BIC,logLL, configColl,XvalData,XvalTime,distribution)
            % FitResult  Construct a result object holding one or more fits.
            %   spikeObj     - nspikeTrain, or a cell array of nspikeTrain objects
            %                  that must share one neuron number and time span.
            %   covLabels    - cell array with the covariate labels of each fit.
            %   numHist, histObjects, ensHistObj - stored as-is.
            %   lambda, b, dev, stats, AIC, BIC, logLL, configColl - forwarded
            %                  to addParamsToFit.
            %   XvalData     - optional validation data (14th input, default []).
            %   XvalTime     - optional validation times (15th input, default []).
            %   distribution - optional per-fit distribution names
            %                  (16th input, default '').
            if(nargin<16)
                distribution = '';
            end
            if(nargin<15)
                XvalTime = [];
            end
            if(nargin<14)
                XvalData = [];
            end
            if(isa(spikeObj,'cell'))
                nTrains = length(spikeObj);
                nNumber = zeros(1,nTrains);
                minTime = zeros(1,nTrains);
                maxTime = zeros(1,nTrains);
                for i=1:nTrains
                    nNumber(i) = FitResult.parseNeuronNumber(spikeObj{i}.name);
                    minTime(i) = spikeObj{i}.minTime;
                    maxTime(i) = spikeObj{i}.maxTime;
                end
                nNumber = unique(nNumber);
                minTime = unique(minTime);
                maxTime = unique(maxTime);
                if(length(nNumber)>1)
                    error('Can only have a FitResults with spike trains from a single neuron');
                end
                if(length(minTime)>1 || length(maxTime)>1)
                    error('Spike Trains are of different lengths');
                end
            elseif(isa(spikeObj,'nspikeTrain'))
                nNumber = FitResult.parseNeuronNumber(spikeObj.name);
                minTime = spikeObj.minTime;
                maxTime = spikeObj.maxTime;
            else
                error('FitResult:InvalidSpikeTrain',...
                    'spikeObj must be an nspikeTrain or a cell array of nspikeTrain objects');
            end
            fitObj.neuronNumber = nNumber;
            fitObj.neuralSpikeTrain = spikeObj;
            fitObj.minTime = minTime;
            fitObj.maxTime = maxTime;
            fitObj.numResults = 0;
            fitObj.configs = configColl;
            fitObj.configNames = configColl.getConfigNames;
            fitObj.covLabels = covLabels;
            fitObj.uniqueCovLabels = getUniqueLabels(covLabels);
            fitObj.mapCovLabelsToUniqueLabels;
            fitObj.numHist = numHist;
            fitObj.histObjects = histObjects;
            fitObj.ensHistObjects = ensHistObj;
            fitObj.addParamsToFit(fitObj.neuronNumber,lambda,b, dev, stats,AIC,BIC,logLL,configColl);
            fitObj.Z = [];
            fitObj.U = [];
            fitObj.X = [];
            fitObj.Residual = [];
            fitObj.KSStats.xAxis = [];
            fitObj.KSStats.KSSorted = [];
            fitObj.KSStats.ks_stat = [];
            fitObj.invGausStats.rhoSig = [];
            fitObj.invGausStats.confBoundSig = [];
            fitObj.plotParams = [];
            fitObj.XvalData = XvalData;
            fitObj.XvalTime = XvalTime;
            fitObj.fitType = distribution;
        end
        function fitObj = setNeuronName(fitObj,name)
            % setNeuronName  Overwrite the stored neuron number/name.
            fitObj.neuronNumber = name;
        end
        function mFitRes = mergeResults(fitObj,newFitObj)
            % mergeResults  Return a new FitResult combining these fits with others.
            %   newFitObj is a FitResult for the same neuron, or a cell array of
            %   FitResult objects that are merged one after another. The fits of
            %   fitObj come first in the result.
            if(isa(newFitObj,'cell') && ~isempty(newFitObj) && isa(newFitObj{1},'FitResult'))
                mFitRes = fitObj.mergeResults(newFitObj{1});
                for i=2:length(newFitObj)
                    mFitRes = mFitRes.mergeResults(newFitObj{i});
                end
                return;
            end
            if(~isa(newFitObj,'FitResult'))
                error('FitResult:InvalidMergeInput',...
                    'newFitObj must be a FitResult or a cell array of FitResult objects');
            end
            if(~isequal(fitObj.neuronNumber,newFitObj.neuronNumber))
                error('FitResult:NeuronMismatch','Neuron number does not match');
            end

            n1 = fitObj.numResults;
            n2 = newFitObj.numResults;
            newIdx = n1+(1:n2);
            spikeObj = fitObj.neuralSpikeTrain;

            % Per-fit cell/vector properties: this object's fits, then the new ones.
            covLabels = fitObj.covLabels(1:n1);
            covLabels(newIdx) = newFitObj.covLabels(1:n2);
            numHist = fitObj.numHist(1:n1);
            numHist(newIdx) = newFitObj.numHist(1:n2);
            histObjects = fitObj.histObjects(1:n1);
            histObjects(newIdx) = newFitObj.histObjects(1:n2);
            ensHistObjects = fitObj.ensHistObjects(1:n1);
            ensHistObjects(newIdx) = newFitObj.ensHistObjects(1:n2);
            b = fitObj.b(1:n1);
            b(newIdx) = newFitObj.b(1:n2);
            stats = fitObj.stats(1:n1);
            stats(newIdx) = newFitObj.stats(1:n2);
            distribution = fitObj.fitType(1:n1);
            distribution(newIdx) = newFitObj.fitType(1:n2);
            dev = [fitObj.dev newFitObj.dev];
            AIC = [fitObj.AIC newFitObj.AIC];
            BIC = [fitObj.BIC newFitObj.BIC];
            logLL = [fitObj.logLL newFitObj.logLL];
            lambda = fitObj.lambda.merge(newFitObj.lambda);

            config = cell(1,n1+n2);
            for i=1:n1
                config{i} = fitObj.configs.getConfig(i);
            end
            for i=1:n2
                config{n1+i} = newFitObj.configs.getConfig(i);
            end
            configColl = ConfigColl(config);
            XvalData = [fitObj.XvalData newFitObj.XvalData];
            XvalTime = [fitObj.XvalTime newFitObj.XvalTime];

            % Rescaled ISIs may differ in length; zero-pad the shorter block.
            Z = FitResult.padConcat(fitObj.Z,newFitObj.Z);
            U = FitResult.padConcat(fitObj.U,newFitObj.U);
            [X,rhoSig,confBoundSig] = Analysis.computeInvGausTrans(Z);
            M = fitObj.Residual.merge(newFitObj.Residual);

            origLength = size(fitObj.KSStats.xAxis,1);
            currLength = size(newFitObj.KSStats.xAxis,1);
            if(currLength~=origLength && origLength>0 && currLength>0)
                % Resample every new KS curve onto this object's x-grid so the
                % columns can be concatenated.
                refX = fitObj.KSStats.xAxis(:,1);
                y = interp1(newFitObj.KSStats.xAxis(:,1),newFitObj.KSStats.KSSorted,refX,'spline','extrap');
                xAxis = [fitObj.KSStats.xAxis repmat(refX,1,size(y,2))];
                KSSorted = [fitObj.KSStats.KSSorted y];
            else
                xAxis = [fitObj.KSStats.xAxis newFitObj.KSStats.xAxis];
                KSSorted = [fitObj.KSStats.KSSorted newFitObj.KSStats.KSSorted];
            end
            ks_stat = [fitObj.KSStats.ks_stat newFitObj.KSStats.ks_stat];

            mFitRes = FitResult(spikeObj,covLabels,numHist,histObjects,ensHistObjects,lambda,b, dev, stats,AIC,BIC,logLL,configColl,XvalData,XvalTime,distribution);
            mFitRes.setKSStats(Z,U, xAxis, KSSorted, ks_stat);
            mFitRes.setInvGausStats(X,rhoSig,confBoundSig);
            mFitRes.setFitResidual(M);
        end
        function subsetFit = getSubsetFitResult(fitObj,subfits)
            % getSubsetFitResult  New FitResult containing only the fits in subfits.
            %   subfits - vector of fit indices in 1..numResults.
            %   Per-fit cell properties (fitType, XvalData, XvalTime) are subset
            %   when they hold one entry per fit; otherwise they are copied.
            if(isempty(subfits) || min(subfits)<1 || max(subfits)>fitObj.numResults)
                error('FitResult:InvalidSubset',...
                    'Fit indices must lie between 1 and %d',fitObj.numResults);
            end
            spikeObj = fitObj.neuralSpikeTrain;
            covLabels = fitObj.covLabels(subfits);
            numHist = fitObj.numHist(subfits);
            histObjects = fitObj.histObjects(subfits);
            ensHistObj = fitObj.ensHistObjects(subfits);
            lambda = fitObj.lambda.getSubSignal(subfits);
            b = fitObj.b(subfits);
            dev = fitObj.dev(subfits);
            stats = fitObj.stats(subfits);
            AIC = fitObj.AIC(subfits);
            BIC = fitObj.BIC(subfits);
            logLL = fitObj.logLL(subfits);
            configColl = fitObj.configs.getSubsetConfigs(subfits);
            XvalData = FitResult.subsetPerFit(fitObj.XvalData,subfits);
            XvalTime = FitResult.subsetPerFit(fitObj.XvalTime,subfits);
            distribution = FitResult.subsetPerFit(fitObj.fitType,subfits);
            subsetFit = FitResult(spikeObj,covLabels,numHist,histObjects,ensHistObj,lambda,b, dev, stats,AIC,BIC,logLL,configColl,XvalData,XvalTime,distribution);
            Z = fitObj.Z(:,subfits);
            U = fitObj.U(:,subfits);
            X = fitObj.X(:,subfits);
            xAxis = fitObj.KSStats.xAxis(:,subfits);
            KSSorted = fitObj.KSStats.KSSorted(:,subfits);
            ks_stat = fitObj.KSStats.ks_stat(subfits);
            rhoSig = fitObj.invGausStats.rhoSig.getSubSignal(subfits);
            confBoundSig = fitObj.invGausStats.confBoundSig;
            M = fitObj.Residual.getSubSignal(subfits);
            subsetFit.setKSStats(Z,U, xAxis, KSSorted, ks_stat);
            subsetFit.setInvGausStats(X,rhoSig,confBoundSig);
            subsetFit.setFitResidual(M);
        end
        function addParamsToFit(fitObj,neuronNum,lambda,b, dev, stats,AIC,BIC,logLL,configColl)
            % addParamsToFit  Append one or more fits to this object (in place).
            %   neuronNum  - must equal fitObj.neuronNumber, else an error is raised.
            %   lambda     - Covariate/SignalObj, or a cell array of them that is
            %                merged in order; its dimension is the number of new fits.
            %   b, stats   - cell arrays with one entry per new fit.
            %   dev        - one deviance value per new fit.
            %   AIC, BIC, logLL - optional, one value per new fit. Each omitted one
            %                is computed: AIC = 2k+dev, BIC = k*log(#time samples)+dev
            %                (k = numel(b{i})), and the Bernoulli log-likelihood
            %                sum(y.*log(p)+(1-y).*log(1-p)) with p = lambda/sampleRate
            %                and y the spike train's signal representation.
            %   configColl - optional; defaults to cell(1,#new fits).
            if(fitObj.neuronNumber==neuronNum)
                if(isa(lambda,'cell'))
                    newLambda = lambda{1};
                    for i=2:length(lambda)
                        newLambda = newLambda.merge(lambda{i});
                    end
                elseif(isa(lambda,'Covariate')||isa(lambda,'SignalObj'))
                    newLambda = lambda;
                else
                    error('FitResult:InvalidLambda',...
                        'lambda must be a Covariate, a SignalObj or a cell array of them');
                end
                numNewResults = newLambda.dimension;
                if(nargin<10)
                    configColl = cell(1,numNewResults);
                end
                haveAIC = (nargin>=7);
                haveBIC = (nargin>=8);
                haveLogLL = (nargin>=9);
                if(~haveLogLL)
                    % Spike indicators and bin width for the Bernoulli likelihood.
                    y = fitObj.neuralSpikeTrain.getSigRep.dataToMatrix;
                    delta = 1/newLambda.sampleRate;
                end
                for i=1:numNewResults
                    k = fitObj.numResults+i;
                    nCoeffs = length(b{i});
                    fitObj.b{k} = b{i};
                    fitObj.dev(k) = dev(i);
                    fitObj.stats{k} = stats{i};
                    if(haveAIC)
                        fitObj.AIC(k) = AIC(i);
                    else
                        fitObj.AIC(k) = 2*nCoeffs+dev(i);
                    end
                    if(haveBIC)
                        fitObj.BIC(k) = BIC(i);
                    else
                        fitObj.BIC(k) = nCoeffs*log(length(newLambda.time))+dev(i);
                    end
                    if(haveLogLL)
                        fitObj.logLL(k) = logLL(i);
                    else
                        p = newLambda.data(:,i)*delta;
                        fitObj.logLL(k) = sum(y.*log(p)+(1-y).*log(1-p));
                    end
                    fitObj.numCoeffs(k) = nCoeffs;
                end
                if(fitObj.numResults==0)
                    fitObj.lambda = newLambda;
                else
                    fitObj.lambda = fitObj.lambda.merge(newLambda);
                end
                fitObj.numResults = fitObj.numResults+numNewResults;
                dataLabels = cell(1,fitObj.numResults);
                for i=1:fitObj.numResults
                    dataLabels{i} = strcat('\lambda_{',num2str(i),'}');
                end
                fitObj.lambda.setDataLabels(dataLabels);
                % NOTE(review): the constructor stores configColl in fitObj.configs
                % and then calls this method with that same configColl, so it is
                % passed to its own addConfig; confirm ConfigColl.addConfig does
                % not duplicate entries in that case.
                fitObj.configs.addConfig(configColl);
                fitObj.configNames = fitObj.configs.getConfigNames;
            else
                error('Neuron number does not match');
            end
        end
        function [lambda, logLL] = computeValLambda(fitObj)
            % computeValLambda  Evaluate every fit on the stored validation data.
            %   lambda - Covariate on XvalTime{1} with one column (Hz) per fit,
            %            column i evaluated on XvalData{i}.
            %   logLL  - per-fit Bernoulli log-likelihood
            %            sum(y.*log(p)+(1-y).*log(1-p)), p = lambda/sampleRate.
            % NOTE(review): y comes from the full neuralSpikeTrain while lambda is
            % on XvalTime{1}; confirm both have the same number of samples.
            nFits = fitObj.numResults;
            lambdaData = zeros(length(fitObj.XvalTime{1}),nFits);
            for i=1:nFits
                lambdaData(:,i) = fitObj.evalLambda(i,fitObj.XvalData{i});
            end
            lambda = Covariate(fitObj.XvalTime{1},lambdaData,...
                '\lambda(t)',fitObj.lambda.xlabelval,...
                fitObj.lambda.xunits,'Hz',fitObj.lambda.dataLabels);
            delta = 1/lambda.sampleRate;
            y = fitObj.neuralSpikeTrain.getSigRep.dataToMatrix;
            p = lambda.data*delta;
            logLL = sum(y.*log(p)+(1-y).*log(1-p));
        end
        function mapCovLabelsToUniqueLabels(fitObj)
            % mapCovLabelsToUniqueLabels  Build indicesToUniqueLabels and flatMask.
            %   For fit j, indicesToUniqueLabels{j}(i) is the position of
            %   covLabels{j}{i} in uniqueCovLabels and flatMask(:,j) marks those
            %   positions with 1. Raises FitResult:UnknownCovariateLabel if a
            %   label is not found.
            flatMask = zeros(length(fitObj.uniqueCovLabels),length(fitObj.covLabels));
            for j=1:length(fitObj.covLabels)
                currLabels = fitObj.covLabels{j};
                index = zeros(1,length(currLabels));
                for i=1:length(currLabels)
                    idx = find(strcmp(currLabels{i}, fitObj.uniqueCovLabels),1,'first');
                    if(isempty(idx))
                        error('FitResult:UnknownCovariateLabel',...
                            'Unable to map covariate label "%s" to unique labels.',currLabels{i});
                    end
                    index(i) = idx;
                end
                fitObj.indicesToUniqueLabels{j} = index;
                flatMask(index,j) = 1;
            end
            fitObj.flatMask = flatMask;
        end
        function p=getPlotParams(fitObj)
            % getPlotParams  Return plotParams, computing them on first use.
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            p = fitObj.plotParams;
        end
        function plotValidation(fitObj)
            % plotValidation  Plot the validation object's results if it is set.
            if(~isempty(fitObj.validation))
                fitObj.validation.plotResults;
            else
                disp('Validation Data not available to plot');
            end
        end
        function answer = isValDataPresent(fitObj)
            % isValDataPresent  1 if any XvalTime entry spans a positive duration.
            %   Returns 0 when XvalTime or XvalData is empty.
            answer = 0;
            if(isempty(fitObj.XvalTime) || isempty(fitObj.XvalData))
                return;
            end
            for i=1:length(fitObj.XvalTime)
                currTime = fitObj.XvalTime{i};
                if(~isempty(currTime) && currTime(end)-currTime(1)>0)
                    answer = 1;
                    return;
                end
            end
        end
        function lambdaData = evalLambda(fitObj,lambdaIndex,newData)
            % evalLambda  Rate predicted by fit lambdaIndex for new data.
            %   newData is a design matrix (double, multiplied by b) or a cell
            %   array whose i-th entry is weighted by b(i). For a 'poisson' fit
            %   the rate is exp(eta); otherwise exp(eta)./(1+exp(eta)). Either is
            %   multiplied by the spike train's sampleRate.
            %   Empty newData yields exp(b(1)*ones(rows,1)), rows = size(newData,1),
            %   without sampleRate scaling.
            if(~(lambdaIndex>0 && lambdaIndex<=fitObj.numResults))
                error('Index into fit params is incorrect');
            end
            b = fitObj.b{lambdaIndex};
            if(isempty(newData))
                [rows,~] = size(newData);
                lambdaData = exp(b(1)*ones(rows,1));
            elseif(isa(newData,'double'))
                if(~isempty(b))
                    lambdaData = exp(newData*b);
                    if(~strcmp(fitObj.fitType{lambdaIndex},'poisson'))
                        lambdaData = lambdaData./(1+lambdaData);
                    end
                end
                lambdaData = lambdaData*fitObj.neuralSpikeTrain.sampleRate;
            elseif(isa(newData,'cell'))
                runSum = 0;
                % Entries of newData beyond the number of coefficients are ignored.
                for i=1:min(length(newData),length(b))
                    runSum = runSum+b(i)*newData{i};
                end
                lambdaData = exp(runSum);
                if(~strcmp(fitObj.fitType{lambdaIndex},'poisson'))
                    lambdaData = lambdaData./(1+lambdaData);
                end
                lambdaData = lambdaData*fitObj.neuralSpikeTrain.sampleRate;
            else
                error('New data must be cell or a matrix');
            end
        end
        function computePlotParams(fitObj,fitNum)
            % computePlotParams  Cache per-label coefficients for plotting.
            %   plotParams.bAct / seAct : (labels used by any fit x fits); NaN
            %       where a fit lacks the label or its se is >= 100.
            %   plotParams.sigIndex     : 1 where bAct-seAct and bAct+seAct have
            %       the same sign and se ~= 0.
            %   plotParams.xLabels      : uniqueCovLabels.
            %   plotParams.numResultsCoeffPresent : number of fits using each label.
            if(nargin<2)
                fitNum = 1:fitObj.numResults;
            end
            usage = sum(fitObj.flatMask,2);
            index = find(usage>0);
            sigIndex = zeros(length(index),length(fitNum));
            bAct = nan(length(index),length(fitNum));
            seAct = nan(length(index),length(fitNum));
            for i=fitNum
                se = fitObj.stats{i}.se;
                % Coefficients with very large standard errors are left as NaN.
                criteria = find(se'<100);
                indicesForFit = fitObj.indicesToUniqueLabels{i};
                bAct(indicesForFit(criteria),i) = fitObj.b{i}(criteria);
                seAct(indicesForFit(criteria),i) = se(criteria)';
                % NOTE(review): significance uses +/-1 se, while plotCoeffs' title
                % mentions 95% CIs; confirm the intended interval width.
                lo = bAct(:,i)-seAct(:,i);
                hi = bAct(:,i)+seAct(:,i);
                sigIndex(:,i) = (sign(lo).*sign(hi)>0) & (seAct(:,i)~=0);
            end
            fitObj.plotParams.bAct = bAct;
            fitObj.plotParams.seAct = seAct;
            fitObj.plotParams.sigIndex = sigIndex;
            fitObj.plotParams.xLabels = fitObj.uniqueCovLabels;
            fitObj.plotParams.numResultsCoeffPresent = usage(index);
        end
        function [coeffIndex, epochId,numEpochs] = getCoeffIndex(fitObj,fitNum,sortByEpoch)
            % getCoeffIndex  Indices (into uniqueCovLabels) of non-history terms.
            %   Labels containing "_{k}" belong to epoch k; other labels have
            %   epoch 0. With sortByEpoch==0 (default) epoch-less terms come
            %   first, followed by each nonzero epoch in increasing order; with
            %   sortByEpoch~=0 the label order is kept.
            %   epochId(j) is the epoch of coeffIndex(j); numEpochs is the number
            %   of distinct epoch ids.
            if(nargin<3 || isempty(sortByEpoch))
                sortByEpoch = 0;
            end
            if(nargin<2 || isempty(fitNum))
                fitNum = 1:fitObj.numResults;
            end
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            histIndex = fitObj.getHistIndex(fitNum,sortByEpoch);
            actCoeffIndex = setdiff(1:length(fitObj.uniqueCovLabels),histIndex);
            allCoeffTerms = fitObj.uniqueCovLabels(actCoeffIndex);
            epochStartInd = regexp(allCoeffTerms,'_\{\d*\}','start');
            epochEndInd = regexp(allCoeffTerms,'_\{\d*\}','end');
            nTerms = length(allCoeffTerms);
            numEpoch = zeros(1,nTerms);
            isPresent = false(1,nTerms);
            hasEpoch = false(1,nTerms);
            for i=1:nTerms
                if(isempty(allCoeffTerms{i}))
                    continue;
                end
                isPresent(i) = true;
                if(~isempty(epochStartInd{i}))
                    hasEpoch(i) = true;
                    % Digits sit between "_{" and "}".
                    numEpoch(i) = str2double(allCoeffTerms{i}(epochStartInd{i}+2:epochEndInd{i}-1));
                end
            end
            allCoeffIndex = find(isPresent);
            if(any(hasEpoch) && ~sortByEpoch)
                coeffIndex = find(isPresent & ~hasEpoch);
                epochId = zeros(size(coeffIndex));
                for e = unique(numEpoch)
                    if(e~=0)
                        inEpoch = find(numEpoch==e);
                        coeffIndex = [coeffIndex, inEpoch]; %#ok<AGROW>
                        epochId = [epochId, e*ones(size(inEpoch))]; %#ok<AGROW>
                    end
                end
                coeffIndex = actCoeffIndex(coeffIndex);
            elseif(any(hasEpoch))
                coeffIndex = actCoeffIndex(allCoeffIndex);
                epochId = numEpoch(allCoeffIndex);
            else
                coeffIndex = actCoeffIndex(allCoeffIndex);
                epochId = zeros(size(allCoeffIndex));
            end
            numEpochs = length(unique(epochId));
        end
        function h=plotCoeffsWithoutHistory(fitObj,fitNum,sortByEpoch,plotSignificance)
            % plotCoeffsWithoutHistory  plotCoeffs restricted to non-history terms.
            if(nargin<4 || isempty(plotSignificance))
                plotSignificance = 1;
            end
            if(nargin<3 || isempty(sortByEpoch))
                sortByEpoch = 0;
            end
            if(nargin<2 || isempty(fitNum))
                fitNum = 1:fitObj.numResults;
            end
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            coeffIndex = fitObj.getCoeffIndex(fitNum,sortByEpoch);
            h = fitObj.plotCoeffs([],fitNum,[],plotSignificance,coeffIndex);
        end
        function [histIndex, epochId,numEpochs] = getHistIndex(fitObj,fitNum,sortByEpoch)
            % getHistIndex  Indices (into uniqueCovLabels) of history terms.
            %   History terms are labels matched by the pattern '^[\w*'.
            %   NOTE(review): presumably meant to match labels starting with '['
            %   (e.g. "[a,b]_{k}"); confirm MATLAB's handling of the open '['.
            %   Labels containing "]_{k}" belong to epoch k, others to epoch 0.
            %   With sortByEpoch==0 (default) history terms are grouped by epoch
            %   in increasing order; otherwise the label order is kept.
            %   epochId(j) is the epoch of histIndex(j); numEpochs is the number
            %   of distinct epoch ids.
            if(nargin<3 || isempty(sortByEpoch))
                sortByEpoch = 0;
            end
            if(nargin<2 || isempty(fitNum))
                fitNum = 1:fitObj.numResults; %#ok<NASGU>
            end
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            labels = fitObj.uniqueCovLabels;
            histStart = regexp(labels,'^[\w*');
            epochStartInd = regexp(labels,'\]_\{\d*\}','start');
            epochEndInd = regexp(labels,'\]_\{\d*\}','end');
            allHistIndex = [];
            histEpoch = [];   % epoch of each entry of allHistIndex
            epochsExist = 0;
            for i=1:length(histStart)
                if(~isempty(histStart{i}))
                    allHistIndex = [allHistIndex i]; %#ok<AGROW>
                    thisEpoch = 0;
                    if(~isempty(epochStartInd{i}))
                        epochsExist = 1;
                        % Digits sit between "]_{" and "}".
                        thisEpoch = str2double(labels{i}(epochStartInd{i}+3:epochEndInd{i}-1));
                    end
                    histEpoch = [histEpoch thisEpoch]; %#ok<AGROW>
                end
            end
            if(epochsExist && ~sortByEpoch)
                % Group only history terms by epoch (non-history labels excluded).
                histIndex = [];
                epochId = [];
                for e = unique(histEpoch)
                    inEpoch = allHistIndex(histEpoch==e);
                    histIndex = [histIndex, inEpoch]; %#ok<AGROW>
                    epochId = [epochId, e*ones(size(inEpoch))]; %#ok<AGROW>
                end
            elseif(epochsExist)
                histIndex = allHistIndex;
                epochId = histEpoch;
            else
                histIndex = allHistIndex;
                epochId = zeros(size(allHistIndex));
            end
            numEpochs = length(unique(epochId));
        end
        function [coeffMat, labels, SEMat] = getCoeffs(fitObj, fitNum)
            % getCoeffs  Non-history coefficients as a (base label x epoch x fit) table.
            %   Page k of coeffMat/SEMat corresponds to fit fitNum(k).
            if(nargin<2 || isempty(fitNum))
                fitNum = 1:fitObj.numResults;
            end
            [coeffIndex, epochId, numEpochs] = fitObj.getCoeffIndex(fitNum,0);
            [coeffMat, labels, SEMat] = fitObj.tabulateByBaseLabel(coeffIndex,epochId,numEpochs,fitNum);
        end
        function [histMat, labels, SEMat] = getHistCoeffs(fitObj,fitNum)
            % getHistCoeffs  History coefficients as a (base label x epoch x fit) table.
            %   Page k of histMat/SEMat corresponds to fit fitNum(k).
            if(nargin<2 || isempty(fitNum))
                fitNum = 1:fitObj.numResults;
            end
            [histIndex, epochId, numEpochs] = fitObj.getHistIndex(fitNum,0);
            [histMat, labels, SEMat] = fitObj.tabulateByBaseLabel(histIndex,epochId,numEpochs,fitNum);
        end
        function h=plotHistCoeffs(fitObj,fitNum,sortByEpoch,plotSignificance)
            % plotHistCoeffs  plotCoeffs restricted to history terms.
            if(nargin<4 || isempty(plotSignificance))
                plotSignificance = 1;
            end
            if(nargin<3 || isempty(sortByEpoch))
                sortByEpoch = 0;
            end
            if(nargin<2 || isempty(fitNum))
                fitNum = 1:fitObj.numResults;
            end
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            histIndex = fitObj.getHistIndex(fitNum,sortByEpoch);
            h = fitObj.plotCoeffs([],fitNum,[],plotSignificance,histIndex);
        end
        function h=plotCoeffs(fitObj,handle,fitNum,plotProps,plotSignificance,subIndex)
            % plotCoeffs  Error-bar plot of fit coefficients (+/- se) per label.
            %   handle           - axes to draw into (default gca).
            %   fitNum           - fits to plot (default all).
            %   plotProps        - optional cell of line specs, one per fit.
            %   plotSignificance - 1 (default) marks plotParams.sigIndex entries
            %                      with '*'.
            %   subIndex         - indices into uniqueCovLabels (default: history
            %                      terms followed by the other terms).
            if(nargin<5 || isempty(plotSignificance))
                plotSignificance = 1;
            end
            if(nargin<4 || isempty(plotProps))
                plotProps = [];
            end
            if(nargin<3 || isempty(fitNum))
                fitNum = 1:fitObj.numResults;
            end
            if(nargin<2 || isempty(handle))
                handle = gca;
            end
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            if(nargin<6 || isempty(subIndex))
                subIndex = [fitObj.getHistIndex, fitObj.getCoeffIndex];
            end
            params = fitObj.getPlotParams;
            bAct = params.bAct(subIndex,fitNum);
            seAct = params.seAct(subIndex,fitNum);
            sigIndex = params.sigIndex(subIndex,fitNum);
            if(~isempty(plotProps))
                for i=1:length(fitNum)
                    h(i) = errorbar(handle,1:length(subIndex),bAct(:,i),seAct(:,i),plotProps{i}); hold on;
                    set(h(i), 'LineStyle', 'none', 'Marker', '.');
                    currColor = get(h(i),'Color');
                    set(h(i),'MarkerEdgeColor',currColor,'MarkerFaceColor',currColor);
                end
            else
                Xaxis = repmat(1:length(bAct(:,1)),[length(bAct(1,:)) 1]);
                h = errorbar(handle,Xaxis',bAct,seAct,'.');
                set(h, 'LineStyle', 'none', 'Marker', '.');
                for n=1:length(h)
                    currColor = get(h(n),'Color');
                    set(h(n),'MarkerEdgeColor',currColor,'MarkerFaceColor',currColor);
                end
            end
            hold on;
            if(plotSignificance==1)
                v = axis(handle);
                vdiff = .8*v(4);
                for i=1:length(fitNum)
                    sigRows = find(sigIndex(:,i)==1);
                    plot(handle,sigRows,vdiff*ones(length(sigRows),1)-i*.1,...
                        strcat('*',FitResult.colors{mod(i-1,length(FitResult.colors))+1})); hold on;
                end
            end
            ylabel('GLM Fit Coefficients','Interpreter','none');
            xtickLabels = params.xLabels(subIndex);
            tickPos = 1:(length(xtickLabels));
            set(handle,'xtick',tickPos,'xtickLabel',xtickLabels,'FontSize',6);
            if(max(fitObj.numCoeffs)>=1)
                xticklabel_rotate([],90,[],'Fontsize',10);
            end
            h_legend = legend(handle,fitObj.lambda.dataLabels(fitNum),'Location','NorthEast');
            set(h_legend,'FontSize',14)
            pos = get(h_legend,'position');
            set(h_legend, 'position',[pos(1)+.05 pos(2) pos(3:4)]);
            title({'GLM Coefficients with 95% CIs (* p<0.05)'},'FontWeight','bold',...
                'FontSize',11,...
                'FontName','Arial');
            set(gca,'FontName', 'Arial' );
            set(gca, ...
                'TickLength' , [.02 .02] , ...
                'YGrid' , 'on' , ...
                'LineWidth' , 1 );
            hx = get(gca,'XLabel'); hy = get(gca,'YLabel');
            set([hx hy],'FontName', 'Arial','FontSize',12,'FontWeight','bold');
        end
        function h=plotResults(fitObj)
            % plotResults  Summary figure: KS plot, ISI autocorrelation, sequential
            % correlation, residual and coefficients.
            scrsz = get(0,'ScreenSize');
            h = figure('OuterPosition',[scrsz(3)*.01 scrsz(4)*.04 scrsz(3)*.98 scrsz(4)*.95]);
            subplot(2,4,[1 2]); fitObj.KSPlot;
            ht = text(.45, .95,strcat('Neuron:',num2str(fitObj.neuronNumber)));
            set(ht,'FontName', 'Arial','FontWeight','bold','FontSize',10);
            subplot(2,4,3); fitObj.plotInvGausTrans;
            subplot(2,4,4); fitObj.plotSeqCorr;
            subplot(2,4,[7 8]); fitObj.plotResidual;
            subplot(2,4,[5 6]); fitObj.plotCoeffs;
        end
        function handle = KSPlot(fitObj,fitNum)
            % KSPlot  KS plot of the stored curves with +/-1.36/sqrt(N) bands,
            %   N being the number of rows of KSStats.KSSorted. Returns the line
            %   handles of the plotted curves ([] when there is no KS data).
            if(nargin<2)
                fitNum = 1:fitObj.numResults;
            end
            handle = [];
            h = gcf;
            figure(h);
            N = size(fitObj.KSStats.KSSorted,1);
            if(~isempty(fitObj.KSStats.xAxis))
                xaxis = fitObj.KSStats.xAxis(:,1);
                plot(xaxis,xaxis, 'k-.'); hold on;
                plot(xaxis, xaxis+1.36/sqrt(N), 'r','Linewidth',1);
                plot(xaxis,xaxis-1.36/sqrt(N), 'r','Linewidth',1 );
                handle = plot(fitObj.KSStats.xAxis(:,fitNum),fitObj.KSStats.KSSorted(:,fitNum),'Linewidth',2);
                axis( [0 1 0 1] );
                dataLabels = fitObj.lambda.dataLabels(fitNum);
                h_legend = legend(handle,dataLabels,'Location','SouthEast');
                set(h_legend,'FontSize',14)
            end
            hx = xlabel('Ideal Uniform CDF');
            hy = ylabel('Empirical CDF');
            title({'KS Plot of Rescaled ISIs'; 'with 95% Confidence Intervals'},'FontWeight','bold','FontSize',11,'FontName','Arial');
            set([hx, hy],'FontName', 'Arial','FontWeight','bold','FontSize',12);
            set(gca, ...
                'TickLength' , [.02 .02] , ...
                'YTick' , 0:.2:1, ...
                'XTick' , 0:.2:1, ...
                'LineWidth' , 1 );
        end
        function structure = toStructure(fitObj)
            % toStructure  Convert the object to a plain struct (see fromStructure).
            %   History objects and Covariate/ConfigColl/nspikeTrain values use
            %   their own toStructure; SignalObj values and the invGausStats
            %   signals use dataToStructure; numeric, logical, char, cell and
            %   struct values are copied. Other property types are skipped.
            fnames = fieldnames(fitObj);
            for i=1:length(fnames)
                name = fnames{i};
                currObj = fitObj.(name);
                if(strcmp(name,'histObjects')||strcmp(name,'ensHistObjects'))
                    for j=1:fitObj.numResults
                        tempObj = currObj{j};
                        if(~isempty(tempObj))
                            structure.(name){j} = tempObj.toStructure;
                        else
                            structure.(name){j} = tempObj;
                        end
                    end
                elseif(strcmp(name,'invGausStats'))
                    tempNames = fieldnames(currObj);
                    for j=1:length(tempNames)
                        tempObj = currObj.(tempNames{j});
                        if(~isempty(tempObj))
                            structure.(name).(tempNames{j}) = tempObj.dataToStructure;
                        else
                            structure.(name).(tempNames{j}) = tempObj;
                        end
                    end
                elseif(isnumeric(currObj)||islogical(currObj)||ischar(currObj)||isa(currObj,'cell'))
                    structure.(name) = currObj;
                elseif(isa(currObj,'Covariate')||isa(currObj,'ConfigColl')||isa(currObj,'nspikeTrain'))
                    structure.(name) = currObj.toStructure;
                elseif(isa(currObj,'SignalObj'))
                    structure.(name) = currObj.dataToStructure;
                elseif(isa(currObj,'struct'))
                    structure.(name) = currObj;
                end
            end
        end
        function handle = plotSeqCorr(fitObj)
            % plotSeqCorr  Scatter of U(j) vs U(j+1) per fit, with correlation
            %   coefficient and p-value (corrcoef) in the legend.
            handle = [];
            rho = zeros(1,fitObj.numResults);
            pval = zeros(1,fitObj.numResults);
            dataLabels = fitObj.lambda.dataLabels;
            for i=1:fitObj.numResults
                handle = plot(fitObj.U(1:end-1,i),fitObj.U(2:end,i),strcat('.',Analysis.colors{mod(i-1,length(Analysis.colors))+1})); hold on;
                [rhoTemp,p] = corrcoef(fitObj.U(1:end-1,i),fitObj.U(2:end,i));
                if(size(rhoTemp,2)>1)
                    rho(i) = rhoTemp(1,2);
                    pval(i) = p(1,2);
                else
                    rho(i) = rhoTemp;
                    pval(i) = p;
                end
                dataLabels{i} = strcat(dataLabels{i},', \rho=',num2str(rho(i),'%0.2g'),' (p=',num2str(pval(i),'%0.2g'),')');
            end
            h_legend = legend(dataLabels,'Location','NorthEast');
            set(h_legend,'FontSize',14)
            pos = get(h_legend,'position');
            if(~isempty(pos))
                set(h_legend, 'position',[pos(1)+.05 pos(2) pos(3:4)]);
            end
            hy = ylabel('u_{j+1}'); hx = xlabel('u_j');
            set([hx, hy],'FontName', 'Arial','FontSize',12,'FontWeight','bold');
            axis([0 1 0 1]);
            title({'Sequential Correlation of'; 'Rescaled ISIs'},'FontWeight','bold',...
                'FontSize',11,...
                'FontName','Arial');
            set(gca, ...
                'TickLength' , [.02 .02] , ...
                'YTick' , 0:.25:1, ...
                'XTick' , 0:.25:1, ...
                'LineWidth' , 1 );
        end
        function handle = plotInvGausTrans(fitObj)
            % plotInvGausTrans  Recompute X from Z (stored in fitObj.X) and plot
            %   the returned autocorrelation signal with its confidence bounds.
            [fitObj.X,rhoSig,confBoundSig] = Analysis.computeInvGausTrans(fitObj.Z);
            handle = [];
            if(~isempty(rhoSig))
                rhoSig.plot;
            end
            h_legend = legend(fitObj.lambda.dataLabels,'Location','NorthEast');
            set(h_legend,'FontSize',14)
            pos = get(h_legend,'position');
            if(~isempty(pos))
                set(h_legend, 'position',[pos(1)+.05 pos(2) pos(3:4)]);
            end
            hold on;
            if(~isempty(confBoundSig))
                confBoundSig.plot;
            end
            title({'Autocorrelation Function';'of Rescaled ISIs'; 'with 95% CIs'},'FontWeight','bold',...
                'FontSize',11,...
                'FontName','Arial');
            hx = get(gca,'XLabel'); hy = get(gca,'YLabel');
            set([hx, hy],'FontName', 'Arial','FontSize',12,'FontWeight','bold');
            set(gca, ...
                'TickLength' , [.02 .02] , ...
                'LineWidth' , 1 );
            v = axis;
            maxY = max(abs(v(3:4)))*(1.1);
            axis([v(1:2) -maxY maxY]);
        end
        function handle = plotResidual(fitObj)
            % plotResidual  Plot the point-process residual with symmetric y-limits.
            handle = fitObj.Residual.plot;
            legend off;
            h_legend = legend(fitObj.lambda.dataLabels,'Location','NorthEast');
            set(h_legend,'FontSize',14)
            pos = get(h_legend,'position');
            if(~isempty(pos))
                set(h_legend, 'position',[pos(1)+.05 pos(2) pos(3:4)]);
            end
            title('Point Process Residual','FontWeight','bold',...
                'FontSize',11,...
                'FontName','Arial');
            xlabel('time [s]','Interpreter','none');
            hx = get(gca,'XLabel'); hy = get(gca,'YLabel');
            set([hx, hy],'FontName', 'Arial','FontSize',12,'FontWeight','bold');
            v = axis;
            maxY = max(abs(v(3:4)))*(1.1);
            axis([v(1:2) -maxY maxY]);
        end
        function setKSStats(fitObj, Z, U, xAxis, KSSorted, ks_stat)
            % setKSStats  Store Z, U and the KS curves, and compare each column of
            %   xAxis with the matching KSSorted column using kstest2.
            %   ks_stat entries are overwritten by kstest2's statistic for every
            %   column of xAxis; withinConfInt = ~h and pValue = p. With no
            %   columns, withinConfInt is false and pValue is 1.
            fitObj.Z = Z;
            fitObj.U = U;
            fitObj.KSStats.xAxis = xAxis;
            fitObj.KSStats.KSSorted = KSSorted;
            differentDists = 1;   % used when xAxis has no columns
            pVal = 1;
            for i=1:size(xAxis,2)
                [differentDists(i),pVal(i),ks_stat(i)] = kstest2(xAxis(:,i),KSSorted(:,i));
            end
            fitObj.KSStats.ks_stat = ks_stat;
            fitObj.KSStats.withinConfInt = ~differentDists;
            fitObj.KSStats.pValue = pVal;
        end
        function setInvGausStats(fitObj, X,rhoSig,confBoundSig)
            % setInvGausStats  Store X and the autocorrelation signals.
            fitObj.X = X;
            fitObj.invGausStats.rhoSig = rhoSig;
            fitObj.invGausStats.confBoundSig = confBoundSig;
        end
        function setFitResidual(fitObj,M)
            % setFitResidual  Store the point-process residual.
            fitObj.Residual = M;
        end
        function [paramVals, paramSE, paramSigIndex] = getParam(fitObj,paramNames,fitNum)
            % getParam  Coefficient, se and significance of named parameters.
            %   paramNames - a label or cell array of labels from uniqueCovLabels.
            %   fitNum     - fits to report (default all).
            %   Each output is (#names x #fits). Unknown names raise
            %   FitResult:UnknownParameter.
            if(nargin<3)
                fitNum = 1:fitObj.numResults;
            end
            if(isempty(fitObj.plotParams))
                fitObj.computePlotParams;
            end
            if(ischar(paramNames))
                paramNames = {paramNames};
            end
            nParams = numel(paramNames);
            paramVals = zeros(nParams,length(fitNum));
            paramSE = zeros(nParams,length(fitNum));
            paramSigIndex = zeros(nParams,length(fitNum));
            for i=1:nParams
                paramIndex = find(strcmp(paramNames{i},fitObj.uniqueCovLabels));
                if(isempty(paramIndex))
                    error('FitResult:UnknownParameter','Unknown parameter "%s".',paramNames{i});
                end
                paramVals(i,:) = fitObj.plotParams.bAct(paramIndex,fitNum);
                paramSE(i,:) = fitObj.plotParams.seAct(paramIndex,fitNum);
                paramSigIndex(i,:) = fitObj.plotParams.sigIndex(paramIndex,fitNum);
            end
        end
    end
    methods (Access = private)
        function [valMat, labels, SEMat] = tabulateByBaseLabel(fitObj,index,epochId,numEpochs,fitNum)
            % tabulateByBaseLabel  Arrange coefficients as (base label x epoch x fit).
            %   index   - indices into uniqueCovLabels to include.
            %   epochId - epoch of each entry of index.
            %   Labels are grouped by their text before "_{k}". Epoch ids are used
            %   as column indices, shifted by one when epoch 0 is present. Page k
            %   of valMat/SEMat holds column fitNum(k) of plotParams.bAct/seAct.
            strs = fitObj.uniqueCovLabels(index);
            baseEnd = regexp(strs,'_\{\d*\}','start');
            baseStrings = cell(1,length(strs));
            for i=1:length(strs)
                if(isempty(baseEnd{i}))
                    baseStrings{i} = strs{i};
                else
                    baseStrings{i} = strs{i}(1:baseEnd{i}-1);
                end
            end
            uniqueCoeffs = unique(baseStrings);
            nUnique = length(uniqueCoeffs);
            shift = ~isempty(epochId) && min(epochId)==0;
            strIndex = cell(1,nUnique);
            epochIndices = cell(1,nUnique);
            for i=1:nUnique
                isMember = strcmp(baseStrings,uniqueCoeffs{i});
                strIndex{i} = index(isMember);
                epochIndices{i} = epochId(isMember)+shift;
            end
            nFits = numel(fitNum);
            valMat = nan(nUnique,numEpochs,nFits);
            SEMat = nan(nUnique,numEpochs,nFits);
            labels = cell(nUnique,numEpochs);
            for i=1:nUnique
                labels(i,epochIndices{i}) = fitObj.uniqueCovLabels(strIndex{i});
            end
            for k=1:nFits
                col = fitNum(k);
                for j=1:nUnique
                    valMat(j,epochIndices{j},k) = fitObj.plotParams.bAct(strIndex{j},col)';
                    SEMat(j,epochIndices{j},k) = fitObj.plotParams.seAct(strIndex{j},col)';
                end
            end
        end
    end
    methods (Static, Access = private)
        function num = parseNeuronNumber(name)
            % parseNeuronNumber  Neuron number from a numeric or text name.
            %   Numeric names are returned unchanged; for text, letters are
            %   removed and the rest is converted with str2double.
            if(isnumeric(name))
                num = name;
            else
                num = str2double(name(~isletter(name)));
            end
        end
        function C = padConcat(A,B)
            % padConcat  [A B], zero-padding rows so both blocks have equal height.
            nRows = max(size(A,1),size(B,1));
            C = zeros(nRows,size(A,2)+size(B,2));
            C(1:size(A,1),1:size(A,2)) = A;
            C(1:size(B,1),size(A,2)+(1:size(B,2))) = B;
        end
        function out = subsetPerFit(values,subfits)
            % subsetPerFit  values(subfits) when values is a cell with an entry
            %   for every requested fit; otherwise values unchanged.
            out = values;
            if(iscell(values) && numel(values)>=max(subfits))
                out = values(subfits);
            end
        end
    end
    methods (Static)
        function fitObj = fromStructure(structure)
            % fromStructure  Rebuild FitResult object(s) from toStructure output.
            %   structure may be a struct, or a cell array of structs (returning a
            %   cell array of the same size). Missing logLL, XvalData, XvalTime
            %   and fitType fields fall back to NaN, [], [] and ''; a
            %   'distribution' field is accepted in place of 'fitType'.
            if(isa(structure,'struct'))
                if(isa(structure.neuralSpikeTrain,'cell'))
                    spikeObj = cell(1,length(structure.neuralSpikeTrain));
                    for k=1:length(structure.neuralSpikeTrain)
                        spikeObj{k} = nspikeTrain.fromStructure(structure.neuralSpikeTrain{k});
                    end
                else
                    spikeObj = nspikeTrain.fromStructure(structure.neuralSpikeTrain);
                end
                lambda = Covariate.fromStructure(structure.lambda);
                rhoSig = SignalObj.signalFromStruct(structure.invGausStats.rhoSig);
                confBoundSig = SignalObj.signalFromStruct(structure.invGausStats.confBoundSig);
                M = Covariate.fromStructure(structure.Residual);
                histObjects = cell(1,structure.numResults);
                ensHistObject = cell(1,structure.numResults);
                for i=1:structure.numResults
                    % toStructure stores empty history entries as-is; keep them empty.
                    if(~isempty(structure.histObjects{i}))
                        histObjects{i} = History.fromStructure(structure.histObjects{i});
                    end
                    if(~isempty(structure.ensHistObjects{i}))
                        ensHistObject{i} = History.fromStructure(structure.ensHistObjects{i});
                    end
                end
                configColl = ConfigColl.fromStructure(structure.configs);
                logLL = NaN(1,max(1,structure.numResults));
                if isfield(structure,'logLL')
                    logLL = structure.logLL;
                    if isscalar(logLL) && structure.numResults > 1
                        logLL = repmat(logLL,1,structure.numResults);
                    end
                end
                XvalData = [];
                if isfield(structure,'XvalData')
                    XvalData = structure.XvalData;
                end
                XvalTime = [];
                if isfield(structure,'XvalTime')
                    XvalTime = structure.XvalTime;
                end
                fitType = '';
                if isfield(structure,'fitType')
                    fitType = structure.fitType;
                elseif isfield(structure,'distribution')
                    fitType = structure.distribution;
                end
                fitObj = FitResult(spikeObj,structure.covLabels,structure.numHist,histObjects,ensHistObject,lambda,structure.b, structure.dev, structure.stats,structure.AIC,structure.BIC,logLL,configColl,XvalData,XvalTime,fitType);
                fitObj.setKSStats(structure.Z,structure.U, structure.KSStats.xAxis, structure.KSStats.KSSorted, structure.KSStats.ks_stat);
                fitObj.setInvGausStats(structure.X,rhoSig,confBoundSig);
                fitObj.setFitResidual(M);
                fitObj.setNeuronName(structure.neuronNumber);
            elseif(isa(structure,'cell'))
                fitObj = cell(size(structure));
                for i=1:numel(structure)
                    fitObj{i} = FitResult.fromStructure(structure{i});
                end
            else
                error('FitResult:InvalidStructure',...
                    'Input must be a struct or a cell array of structs');
            end
        end
        function structCell = CellArrayToStructure(fitResObjCell)
            % CellArrayToStructure  toStructure for a FitResult or a cell of them.
            %   Returns {} for any other input.
            structCell = {};
            if(isa(fitResObjCell,'FitResult'))
                structCell = fitResObjCell.toStructure;
            elseif(isa(fitResObjCell,'cell') && ~isempty(fitResObjCell) && isa(fitResObjCell{1},'FitResult'))
                structCell = cell(size(fitResObjCell));
                for i=1:numel(fitResObjCell)
                    structCell{i} = fitResObjCell{i}.toStructure;
                end
            end
        end
    end
end
function hText = xticklabel_rotate(XTick,rot,varargin)
% xticklabel_rotate  Replace the current axes' x tick labels by rotated text.
%   hText = xticklabel_rotate(XTick,rot,labels,prop,value,...)
%   XTick  - tick positions (default: the current XTick).
%   rot    - rotation in degrees (default 90).
%   labels - optional cell array of labels. When neither XTick nor labels is
%            given, the current XTickLabel strings are used; otherwise labels
%            default to num2str(XTick). An empty placeholder is skipped.
%   Remaining prop/value pairs are applied to the created text objects.
%   The axes are shrunk so the rotated labels fit, and the x/y labels and
%   title are restored to their original data positions.
if isempty(get(gca,'XTickLabel'))
    error('xticklabel_rotate : can not process, either xticklabel_rotate has already been run or XTickLabel field has been erased') ;
end
if nargin < 2
    rot = 90 ;
end
if (nargin < 3 || isempty(varargin{1})) && (~exist('XTick','var') || isempty(XTick))
    xTickLabels = get(gca,'XTickLabel') ;
    if ~iscell(xTickLabels)
        temp1 = num2cell(xTickLabels,2) ;
        for loop = 1:length(temp1)
            temp1{loop} = deblank(temp1{loop}) ;
        end
        xTickLabels = temp1 ;
    end
    varargin = varargin(2:length(varargin));
end
if ~exist('XTick','var') || isempty(XTick)
    XTick = get(gca,'XTick') ;
end
XTick = XTick(:);
if ~exist('xTickLabels','var')
    if ~isempty(varargin) && iscell(varargin{1})
        xTickLabels = varargin{1};
        varargin = varargin(2:length(varargin));
    else
        if ~isempty(varargin) && isempty(varargin{1})
            % Drop the empty label placeholder so it is not passed to set().
            varargin = varargin(2:length(varargin));
        end
        xTickLabels = num2str(XTick);
    end
end
% A char matrix holds one label per row; a cell array one label per element.
if ischar(xTickLabels)
    nLabels = size(xTickLabels,1);
else
    nLabels = numel(xTickLabels);
end
if length(XTick) ~= nLabels
    error('xticklabel_rotate : must have same number of elements in "XTick" and "XTickLabel"') ;
end
set(gca,'XTick',XTick,'XTickLabel','')
hxLabel = get(gca,'XLabel');
set(hxLabel,'Units','data');
xLabelPosition = get(hxLabel,'Position');
y = xLabelPosition(2);
y = repmat(y,size(XTick,1),1);
fs = get(gca,'fontsize');
hText = text(XTick, y, xTickLabels,'fontsize',fs);
set(hText,'Rotation',rot,'HorizontalAlignment','right',varargin{:})
% Remember label/title positions in data units before switching to pixels.
set(get(gca,'xlabel'),'units','data') ;
labxorigpos_data = get(get(gca,'xlabel'),'position') ;
set(get(gca,'ylabel'),'units','data') ;
labyorigpos_data = get(get(gca,'ylabel'),'position') ;
set(get(gca,'title'),'units','data') ;
labtorigpos_data = get(get(gca,'title'),'position') ;
set(gca,'units','pixel') ;
set(hText,'units','pixel') ;
set(get(gca,'xlabel'),'units','pixel') ;
set(get(gca,'ylabel'),'units','pixel') ;
origpos = get(gca,'position') ;
temphText = get(hText,'extent');
if(isa(temphText,'cell'))
    textsizes = cell2mat(temphText) ;
else
    textsizes = temphText;
end
longest = max(textsizes(:,4)) ;
laborigpos = get(get(gca,'xlabel'),'position') ;
labyorigpos = get(get(gca,'ylabel'),'position') ;
leftpos = get(hText(1),'position') ;
leftext = get(hText(1),'extent') ;
leftdist = leftpos(1) + leftext(1) ;
if leftdist > 0, leftdist = 0 ; end
botdist = origpos(2) + laborigpos(2) ;
% Shrink the axes so the tallest rotated label fits below them.
newpos = [origpos(1)-leftdist longest+botdist origpos(3)+leftdist origpos(4)-longest+origpos(2)-botdist] ;
set(gca,'position',newpos) ;
set(hText,'units','data') ;
for loop = 1:length(hText)
    set(hText(loop),'position',[XTick(loop), y(loop)]) ;
end
laborigpos = get(get(gca,'xlabel'),'position') ;
set(get(gca,'xlabel'),'position',[laborigpos(1) laborigpos(2)-longest 0]) ;
set(get(gca,'ylabel'),'units','data') ;
set(get(gca,'ylabel'),'position',labyorigpos_data) ;
set(get(gca,'title'),'position',labtorigpos_data) ;
set(get(gca,'xlabel'),'units','data') ;
labxorigpos_data_new = get(get(gca,'xlabel'),'position') ;
set(get(gca,'xlabel'),'position',[labxorigpos_data(1) labxorigpos_data_new(2)]) ;
set(get(gca,'xlabel'),'units','normalized') ;
set(get(gca,'ylabel'),'units','normalized') ;
set(get(gca,'title'),'units','normalized') ;
set(hText,'units','normalized') ;
set(gca,'units','normalized') ;
if nargout < 1
    clear hText
end
end
function [uniqueLabels, indexIntoOriginal, restoreIndex] = getUniqueLabels(covLabels)
% getUniqueLabels  Sorted union of the labels of several label lists.
%   covLabels is a cell array whose entries are cell arrays of label strings.
%   All entries are concatenated in order into one row cell array, which is
%   passed to unique; its three outputs are returned unchanged. Empty
%   covLabels, or empty entries, are allowed.
allLabels = cell(1,0);
for i=1:numel(covLabels)
    currLabels = covLabels{i};
    allLabels = [allLabels, reshape(currLabels,1,[])]; %#ok<AGROW>
end
[uniqueLabels, indexIntoOriginal, restoreIndex] = unique(allLabels);
end