Skip to content

Commit

Permalink
Minor upgrade to ADEM routines
Browse files Browse the repository at this point in the history
SVN r3333
  • Loading branch information
Friston committed Aug 25, 2009
1 parent 00ce090 commit ba00d49
Show file tree
Hide file tree
Showing 12 changed files with 581 additions and 31 deletions.
32 changes: 16 additions & 16 deletions spm_ADEM_M_set.m
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 49,7 @@
% Copyright (C) 2005 Wellcome Department of Imaging Neuroscience

% Karl Friston
% $Id: spm_ADEM_M_set.m 1961 2008-07-26 09:38:46Z karl $
% $Id: spm_ADEM_M_set.m 3333 2009-08-25 16:12:44Z karl $

% order
%--------------------------------------------------------------------------
Expand Down Expand Up @@ -127,7 127,7 @@

% Assume fixed parameters if not specified
%----------------------------------------------------------------------
if length(M(i).pC) == 0
if isempty(M(i).pC)
p = length(spm_vec(M(i).pE));
M(i).pC = sparse(p,p);
end
Expand All @@ -154,12 154,12 @@
catch
v = sparse(0,0);
end
if ~length(v)
if isempty(v)
try
v = sparse(M(g - 1).m,1);
end
end
if ~length(v)
if isempty(v)
try
v = sparse(M(g).l,1);
end
Expand All @@ -175,7 175,7 @@
catch
a = sparse(0,0);
end
if ~length(a)
if isempty(a)
try
a = sparse(M(i).k,1);
end
Expand All @@ -193,7 193,7 @@
catch
x = sparse(M(i).n,1);
end
if ~length(x) && M(i).n
if isempty(x) && M(i).n
x = sparse(M(i).n,1);
end

Expand All @@ -205,7 205,7 @@
try
f = feval(M(i).f,x,v,a,M(i).pE);
if length(spm_vec(x)) ~= length(spm_vec(f))
errordlg('please check nargout: M(%i).f(x,v,a,P)',i);
errordlg(sprintf('please check nargout: M(%i).f(x,v,a,P)',i));
end
catch
errordlg(sprintf('evaluation failure: M(%i).f(x,v,a,P)',i))
Expand Down Expand Up @@ -236,7 236,7 @@
% remove empty levels
%--------------------------------------------------------------------------
try
g = min(find(~spm_vec(M.m)));
g = find(~spm_vec(M.m),1);
M = M(1:g);
catch
errordlg('please specify number of variables')
Expand All @@ -263,8 263,8 @@

% make sure components are cell arrays
%----------------------------------------------------------------------
if length(M(i).Q) & ~iscell(M(i).Q), M(i).Q = {M(i).Q}; end
if length(M(i).R) & ~iscell(M(i).R), M(i).R = {M(i).R}; end
if ~isempty(M(i).Q) && ~iscell(M(i).Q), M(i).Q = {M(i).Q}; end
if ~isempty(M(i).R) && ~iscell(M(i).R), M(i).R = {M(i).R}; end

% check hyperpriors
%======================================================================
Expand All @@ -276,16 276,16 @@

% check hyperpriors (expectations)
%----------------------------------------------------------------------
if ~length(M(i).hE), M(i).hE = sparse(length(M(i).Q),1); end
if ~length(M(i).gE), M(i).gE = sparse(length(M(i).R),1); end
if isempty(M(i).hE), M(i).hE = sparse(length(M(i).Q),1); end
if isempty(M(i).gE), M(i).gE = sparse(length(M(i).R),1); end

% check hyperpriors (covariances)
%----------------------------------------------------------------------
try, M(i).hC*M(i).hE; catch, M(i).hC = speye(length(M(i).hE))*256; end
try, M(i).gC*M(i).gE; catch, M(i).gC = speye(length(M(i).gE))*256; end

if ~length(M(i).hC), M(i).hC = speye(length(M(i).hE))*256; end
if ~length(M(i).gC), M(i).gC = speye(length(M(i).gE))*256; end
if isempty(M(i).hC), M(i).hC = speye(length(M(i).hE))*256; end
if isempty(M(i).gC), M(i).gC = speye(length(M(i).gE))*256; end

% check Q and R (precision components)
%======================================================================
Expand Down Expand Up @@ -339,7 339,7 @@

% remove fixed components if hyperparameters exist
%----------------------------------------------------------------------
if length(M(i).hE)
if ~isempty(M(i).hE)
M(i).V = sparse(M(i).l,M(i).l);
end

Expand All @@ -355,7 355,7 @@

% remove fixed components if hyperparameters exist
%----------------------------------------------------------------------
if length(M(i).gE)
if ~isempty(M(i).gE)
M(i).W = sparse(M(i).n,M(i).n);
end

Expand Down
12 changes: 5 additions & 7 deletions spm_DEM_M_set.m
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 47,7 @@
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: spm_DEM_M_set.m 3058 2009-04-09 18:17:53Z karl $
% $Id: spm_DEM_M_set.m 3333 2009-08-25 16:12:44Z karl $

% order
%--------------------------------------------------------------------------
Expand Down Expand Up @@ -126,7 126,7 @@

% Assume fixed parameters if not specified
%----------------------------------------------------------------------
if length(M(i).pC) == 0
if isempty(M(i).pC)
p = length(spm_vec(M(i).pE));
M(i).pC = sparse(p,p);
end
Expand Down Expand Up @@ -187,9 187,7 @@
try
f = feval(M(i).f,x,v,M(i).pE);
if length(spm_vec(x)) ~= length(spm_vec(f))
str = sprintf('please check: M(%i).f(x,v,P)',i);
msgbox(str)
error(' ')
errordlg(sprintf('please check: M(%i).f(x,v,P)',i));
end

catch
Expand Down Expand Up @@ -243,8 241,8 @@

% make sure components are cell arrays
%----------------------------------------------------------------------
if ~isempty(M(i).Q) & ~iscell(M(i).Q), M(i).Q = {M(i).Q}; end
if ~isempty(M(i).R) & ~iscell(M(i).R), M(i).R = {M(i).R}; end
if ~isempty(M(i).Q) && ~iscell(M(i).Q), M(i).Q = {M(i).Q}; end
if ~isempty(M(i).R) && ~iscell(M(i).R), M(i).R = {M(i).R}; end

% check hyperpriors
%======================================================================
Expand Down
4 changes: 2 additions & 2 deletions spm_cat.m
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 22,7 @@
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: spm_cat.m 1172 2008-02-27 20:14:47Z karl $
% $Id: spm_cat.m 3333 2009-08-25 16:12:44Z karl $

% check x is not already a matrix
%--------------------------------------------------------------------------
Expand Down Expand Up @@ -80,7 80,7 @@
[n m] = size(x);
for i = 1:n
for j = 1:m
if ~length(x{i,j})
if isempty(x{i,j})
x{i,j} = sparse(I(i),J(j));
end
end
Expand Down
4 changes: 2 additions & 2 deletions toolbox/DEM/ADEM_mountaincar_loss_3.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 12,7 @@
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: ADEM_mountaincar_loss_3.m 3140 2009-05-21 18:38:17Z karl $
% $Id: ADEM_mountaincar_loss_3.m 3333 2009-08-25 16:12:44Z karl $

% generative process (mountain car terrain)
%========================================================================== % switch for demo
Expand Down Expand Up @@ -162,7 162,7 @@

% loss-functions or priors
%--------------------------------------------------------------------------
q0 = sparse(nq,1);
q0 = sparse(nq,1) - 1;
q = sparse(nq*3/4,1,-32,nq,1) 1;
C0 = spm_DEM_basis(x,[],q0);
C = spm_DEM_basis(x,[],q);
Expand Down
187 changes: 187 additions & 0 deletions toolbox/DEM/ADEM_mountaincar_loss_4.m
Original file line number Diff line number Diff line change
@@ -0,0 1,187 @@
% This demo re-visits the mountain car problem to show that adaptive
% (desired) behaviour can be prescribed in terms of loss-functions (i.e.
% reward functions of state-space).
% It exploits the fact that under the free-energy formulation, loss is
% divergence. This means that priors can be used to make certain parts of
% state-space costly (i.e. with high divergence) and others rewarding (low
% divergence). Active inference under these priors will lead to sampling of
% low cost states and (apparent) attractiveness of those states.
%
% This is version four; that includes a drive state.
%__________________________________________________________________________
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: ADEM_mountaincar_loss_4.m 3333 2009-08-25 16:12:44Z karl $

% generative process (mountain car terrain)
%========================================================================== % switch for demo
clear

% parameters of generative process
%--------------------------------------------------------------------------
P.a = 0;
P.b = [0 0];
P.c = [0 0 0 0];
P.d = 1; % action on

fx = inline('spm_mc_fxa_4(x,v,a,P)','x','v','a','P');
gx = inline('[x.x; x.v; x.d]','x','v','a','P');
x0.x = 0;
x0.v = 0;
x0.p = 0;
x0.d = 0;


% level 1
%--------------------------------------------------------------------------
G(1).x = x0;
G(1).f = fx;
G(1).g = gx;
G(1).pE = P;
G(1).V = exp(16); % error precision
G(1).W = exp(16); % error precision

% level 2
%--------------------------------------------------------------------------
G(2).a = 0; % action
G(2).v = 0; % inputs
G(2).V = exp(16);
G = spm_ADEM_M_set(G);


% generative model
%==========================================================================
clear P x0

% parameters (previously learned) and equations of motion
%--------------------------------------------------------------------------
P = [2.7 1.7 0.74 -0.51 -0.85 0.08 -0.23 -1.15];
np = length(P);
fx = inline('spm_mc_fx_4(x,v,P)','x','v','P');
gx = inline('[x.x; x.v; x.d]','x','v','P');
x0.x = 0;
x0.v = 0;
x0.c = 0;
x0.d = 0;

% level 1
%--------------------------------------------------------------------------
M(1).x = x0;
M(1).f = fx;
M(1).g = gx;
M(1).pE = P;
M(1).V = exp(8); % error precision
M(1).W = diag(exp([8 4 16 16])); % error precision

% level 2
%--------------------------------------------------------------------------
M(2).v = 0; % inputs
M(2).V = exp(16);
M = spm_DEM_M_set(M);


% learn gradients with a flat loss-functions (priors on divergence)
%==========================================================================
N = 1600;
U = sparse(N,M(1).m);
DEM.U = U;
DEM.C = U;
DEM.G = G;
DEM.M = M;
DEM = spm_ADEM(DEM);


% show dynamics
%==========================================================================

% inference
%--------------------------------------------------------------------------
spm_figure('GetWin','Graphics');
spm_DEM_qU(DEM.qU)

% true and inferred position
%--------------------------------------------------------------------------
subplot(2,2,3)
plot(DEM.pU.x{1}(1,:),DEM.pU.x{1}(2,:)),hold on
plot( 1,0,'r.','Markersize',32), hold on
plot(-1/2,0,'g.','Markersize',16), hold off
xlabel('position','Fontsize',14)
ylabel('velcitiy','Fontsize',14)
title('trajectories','Fontsize',16)
axis([-1 1 -1 1]*3)
axis square

% true position
%--------------------------------------------------------------------------
subplot(2,2,3)
plot3(DEM.pU.x{1}(1,:),DEM.pU.x{1}(2,:),1:N), hold on
plot3( 1,0,1:64:N,'r.','Markersize',8), hold on
plot3(-1/2,0,1:64:N,'g.','Markersize',8), hold off
xlabel('position','Fontsize',14)
ylabel('velocity','Fontsize',14)
zlabel('time','Fontsize',14)
title('trajectories','Fontsize',16)
axis([-2 2 -2 2 0 N])
axis square


% real states
%==========================================================================
spm_figure('GetWin','DEM');
spm_DEM_qU(DEM.pU)


subplot(2,2,1)
plot3( 1,0,1:1/8:2,'r.'), hold on
plot3(-1/2,0,1:1/8:2,'g.'), hold on
plot3(DEM.pU.x{1}(1,:),DEM.pU.x{1}(2,:),DEM.pU.x{1}(4,:)), hold off
xlabel('position','Fontsize',14)
ylabel('velocity','Fontsize',14)
zlabel('satiety','Fontsize',14)
title('trajectories','Fontsize',16)
axis([-2 2 -2 2 0 8])



% cost function (see spm_mc_fx_4.m)
%--------------------------------------------------------------------------
subplot(2,1,2)
x = -2:1/64:2;
d = 0:1/64:2;
for i = 1:length(x)
for j = 1:length(d)
D = spm_phi((1 - d(j))*8);
A = 2 - 32*exp(-(x(i) - 1).^2*32);
C(i,j) = A*D - 1;
end
end

surf(d,x,C)
shading interp
xlabel('drive','Fontsize',14)
ylabel('position','Fontsize',14)
title('cost-function','Fontsize',16)
axis square

return

% NOTES for graphics
%--------------------------------------------------------------------------
clear C
r = 0:1/64:1;
d = 0:1/64:4;
for i = 1:length(r)
for j = 1:length(d)
D = spm_phi((1 - d(j))*8);
A = 2 - 32*r(i);
C(i,j) = A*D - 1;
end
end

imagesc(d,r,C)
shading interp
xlabel('satiety','Fontsize',14)
ylabel('reward','Fontsize',14)
axis xy

Loading

0 comments on commit ba00d49

Please sign in to comment.