State observer #309
-
I have a simple spring-mass-damper system for which I want to make a state-observer: using RxInfer
using Random
using LinearAlgebra
using CairoMakie
function f(x, u, p, t, tS)
return x tS*[
x[2]
-p.c / p.m * x[1] - p.d / p.m * x[2] - 1 / p.m * u[1]
]
end
g(x, u, p, t) = [x[1]]
tS = 1/100
t_sim = 0:tS:120
u_sim = reduce(
hcat,
map(t -> [exp(-0.1t) * sin(1 / 10 * 2 * pi * t) sin(1 / 20 * 2 * pi * t)], t_sim),
)
x0 = [0, 0]
function generate_data(rng, f, g, tS, x0, u, t, Q, R)
x_prev = x0
x = Vector{Vector{Float64}}(undef, length(t))
y = Vector{Vector{Float64}}(undef, length(t))
for i in 1:length(t)
x[i] = rand(rng, MvNormal(f(x_prev, u[:,i], p, t[i], tS), Q))
y[i] = rand(rng, MvNormal(g(x[i], u[:,i], p, t[i]), R))
x_prev = x[i]
end
return x, y
end
seed = 1234
rng = MersenneTwister(1234)
Q = 0.1 .* diageye(2)
R = 10 .* diageye(1)
p = (c=1.0, d=0.1, m=0.8)
x, y = generate_data(rng, f, g, tS, x0, u_sim, t_sim, Q, R)
let
fig = Figure()
ax1 = Axis(fig[1,1])
lines!(t_sim, getindex.(x, 1), label = "Hidden Signal (dim-1)", color = :teal)
lines!(t_sim, getindex.(x, 2), label = "Hidden Signal (dim-2)", color = :violet)
ax2 = Axis(fig[2,1])
lines!(t_sim, getindex.(y, 1), label = "Observation", color = :orange)
fig
end However, I have some problems with the model specification: @model function ssm(y, f, g, tS, x0, u, t, Q, R)
x_prior ~ x0
x_prev = x_prior
for i in eachindex(t)
x[i] ~ MvNormalMeanCovariance(mu=f(x_prev, u[:,i], p, t[i], tS), var=Q)
y[i] ~ MvNormalMeanCovariance(mu=g(x[i], u[:,i], p, t[i]), var=R)
x_prev = x[i]
end
end
# prior_x: Based on first observation, assuming initial state is similar with equal variance.
prior_x = MvNormalMeanCovariance(x[1], diageye(2))
model = ssm(f=f, g=g, tS=tS, x0=prior_x, u=u_sim, t=t_sim, Q=Q, R=R)
ssm_meta = @meta begin
f() -> Unscented()
g() -> Unscented()
end
imessages = @initialization begin
μ(x) = prior_x
end
result = infer(
model = model,
data = (y=y, ),
options = (limit_stack_depth = 500, ),
returnvars = KeepLast(),
predictvars = KeepLast(),
initialization = imessages,
meta = ssm_meta,
iterations = 20,
showprogress=true,
) ERROR: AssertionError: Expected only one missing interface, got (:out, :μ, :Σ) of length 3 (node MvNormal with interfaces (:mu, :var))
Stacktrace:
[1] prepare_interfaces(::GraphPPL.StaticInterfaces{…}, fform::Type{…}, lhs_interface::GraphPPL.NodeLabel, rhs_interfaces::@NamedTuple{…})
@ GraphPPL C:\Users\KaisermayerV\.julia\packages\GraphPPL\Z49xA\src\graph_engine.jl:1622 Eventually, I would like to infer an unknown |
Beta Was this translation helpful? Give feedback.
Replies: 3 comments 19 replies
-
It should be: @model function ssm(y, f, g, tS, x0, u, t, Q, R)
x_prior ~ x0
x_prev = x_prior
for i in eachindex(t)
x[i] ~ MvNormalMeanCovariance(μ=f(x_prev, u[:,i], p, t[i], tS), Σ=Q)
y[i] ~ MvNormalMeanCovariance(μ=g(x[i], u[:,i], p, t[i]), Σ=R)
x_prev = x[i]
end
end However, it does not terminate. |
Beta Was this translation helpful? Give feedback.
-
I have something working; however, the results are a bit underwhelming, considering that this is just the same as the linear Kalman filter example. result = infer(
model = model,
data = (y=y, ),
options = (limit_stack_depth = 10, ),
initialization = imessages,
meta = ssm_meta,
iterations = 5,
showprogress=true,
)
xmarginals = result.posteriors[:x][1] # why is this a 1-element vector?
let
fig = Figure(size=(800,400))
ax1 = Axis(fig[1,1])
lines!(t_sim, getindex.(x, 1), label = "Hidden Signal (dim-1)", color = :orange)
lines!(t_sim, getindex.(x, 2), label = "Hidden Signal (dim-2)", color = :green)
band!(t_sim, getindex.(mean.(xmarginals), 1) .- (getindex.(var.(xmarginals), 1) .|> sqrt), getindex.(mean.(xmarginals), 1) . (getindex.(var.(xmarginals), 1) .|> sqrt), label = "Estimated Signal (dim-1)", color = (:violet, 0.5))
band!(t_sim, getindex.(mean.(xmarginals), 2) .- (getindex.(var.(xmarginals), 2) .|> sqrt), getindex.(mean.(xmarginals), 2) . (getindex.(var.(xmarginals), 2) .|> sqrt), label = "Estimated Signal (dim-2)", color = (:teal, 0.5))
lines!(t_sim, getindex.(mean.(xmarginals), 1), label = "Estimated Signal (dim-1)", color = :violet)
lines!(t_sim, getindex.(mean.(xmarginals), 2), label = "Estimated Signal (dim-2)", color = :teal)
Legend(fig[1,2], ax1, merge=true, unique=false)
fig
end |
Beta Was this translation helpful? Give feedback.
-
I tested the solution, and the progress bar shows 8 s, however, |
Beta Was this translation helpful? Give feedback.
@ValentinKaisermayer I decided to play with your model a bit to also infer the p parameter and employ @bvdmitri's suggestion to remove the g function. For me, the inference now runs for 20 seconds, taking 1 second per iteration.