Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Heat Equation Enzyme Fail #183

Open
jClugstor opened this issue Oct 11, 2024 · 0 comments
Open

Heat Equation Enzyme Fail #183

jClugstor opened this issue Oct 11, 2024 · 0 comments

Comments

@jClugstor
Copy link
Collaborator

This worked on previous versions of Enzyme, DiffEq, etc.; with the current versions, the script below fails.

using Decapodes
using DiagrammaticEquations
using CombinatorialSpaces
using GeometryBasics
using MLStyle
using ComponentArrays
using OrdinaryDiffEq
using CairoMakie
import CairoMakie: wireframe, mesh, Figure, Axis
import Decapodes:nparts, mul!
using Zygote
using SciMLSensitivity
using DataFrames
using Chairmarks

# Point-type aliases for the embedded mesh geometry.
Point2D, Point3D = Point2{Float64}, Point3{Float64}

# Primal triangulated grid (arguments assumed to be a 100 × 100 extent with
# spacing 2 in each direction — TODO confirm against `triangulated_grid`).
rect = triangulated_grid(100, 100, 2, 2, Point3D)

# Dual complex over `rect`, subdivided in place with circumcentric dual vertices.
d_rect = let dual = EmbeddedDeltaDualComplex2D{Bool,Float64,Point3D}(rect)
    subdivide_duals!(dual, Circumcenter())
    dual
end

#fig = Figure()
#ax = CairoMakie.Axis(fig[1, 1])
#wf = wireframe!(ax, rect)
#fig

# Heat equation for the 0-form `U` (one value per mesh vertex, see the initial
# condition built from `d_rect[:point]` below):
#     ∂ₜU = k / (c ρ) ⋅ ΔU
# The constants `k`, `c` and `ρ` are taken from the ODE parameter vector at solve time.
Heat = @decapode begin
    U::Form0

    k::Constant
    c::Constant
    ρ::Constant

    ∂ₜ(U) == k/(c*ρ) * Δ(U)
end


# Generate simulation code for `Heat` (with preallocated caches), write it to
# `Heat_alloc.jl`, and `include` it; the written file defines `decapode_f`.
# Keeping the generated source on disk makes it inspectable and lets stack
# traces point at lines of the generated file.
decapode_code = gensim(Heat, preallocate=true)

# Resolve the output file next to this script instead of a hard-coded,
# machine-specific absolute path. (When pasted into a REPL, `@__DIR__` is the
# current working directory.)
heat_alloc_path = joinpath(@__DIR__, "Heat_alloc.jl")

# The do-block form closes the file even if writing throws.
open(heat_alloc_path, "w") do file
    write(file, string("decapode_f = ", decapode_code))
end
include(heat_alloc_path)

#fₘ  = evalsim(Heat, preallocate=true)

# Instantiate the simulation right-hand side on the dual mesh. The second
# argument is `nothing` (presumably "no extra operator definitions" — confirm
# against the generated `decapode_f`).
fₘ = decapode_f(d_rect, nothing)

# Initial temperature: each vertex takes the value of its first (x) coordinate.
U = map(first, d_rect[:point])



#fig = Figure()
#ax = CairoMakie.Axis(fig[1, 1])
#msh = mesh!(ax, rect, color=U, colormap=:jet)
#fig


# Reference values for pure aluminum. They are NOT used below; the demo runs
# with unit-scaled parameters instead.
# https://ncfs.ucf.edu/burn_db/Thermal_Properties/material_thermal.html
# k = 237, ρ = 2702, Cp = 0.903, ϵ = 0.03
u₀ = ComponentArray(U=U)

# Float64 parameters, so this vector has the same element type as the one
# `loss` builds from a Float64 `k`. Integer-typed values are treated as
# non-differentiable by AD tools such as Enzyme. k/(c*ρ) evaluates to 100.0
# either way, so the dynamics are unchanged.
constants_and_parameters = ComponentArray(k = 100.0, c = 1.0, ρ = 1.0)

# Final time of the reference simulation.
tₑ = 20.0

# A very short solve first, so the solver is compiled before the real run.
@info("Precompiling Solver")
prob = ODEProblem(fₘ, u₀, (0.0, 1e-4), constants_and_parameters)
soln = solve(prob, Tsit5())
# `retcode` is a SciMLBase `ReturnCode` enum. Compare against it directly rather
# than a Symbol, and treat any non-success (not only `Unstable`) as a failure.
soln.retcode == ReturnCode.Success || error("Solver was not stable")

@info("Solving")
data_prob = ODEProblem(fₘ, u₀, (0.0, tₑ), constants_and_parameters)
soln = solve(data_prob, Tsit5())
# The final state of this solve becomes the calibration target, so a failed
# solve must not pass silently.
soln.retcode == ReturnCode.Success ||
    error("Reference solve failed with retcode $(soln.retcode)")
@info("Done")


#fig = Figure()
#ax = CairoMakie.Axis(fig[1, 1])
#msh = mesh!(ax, rect, color=soln(tₑ).U, colormap=:jet)
#fig

# Calibration target: the `U` component of the final saved state.
reference_dat = soln.u[end].U

"""
    loss(u; sensealg = GaussAdjoint(autojacvec = EnzymeVJP()))

Return the sum of squared differences between `reference_dat` and the final `U`
state of `data_prob`. Before solving, `data_prob` is re-parameterized with
`k = u[1]` (`c` and `ρ` fixed at 1). Only the last time step is compared.

`u` may be a scalar or a vector; only `u[1]` is read. `sensealg` selects the
adjoint used when differentiating through `solve`. The default uses an Enzyme
VJP; pass e.g. `GaussAdjoint(autojacvec = ZygoteVJP())` to compare backends.
"""
function loss(u; sensealg = GaussAdjoint(autojacvec = EnzymeVJP()))
    newp = ComponentArray(k = u[1], c = 1, ρ = 1)
    prob = remake(data_prob, p = newp)
    sol = solve(prob, Tsit5(), sensealg = sensealg)
    current_dat = last(sol).U
    return sum(abs2, reference_dat .- current_dat)
end

# Differentiate `loss` with respect to `k` at k = 99.0; the reference data was
# generated with k = 100. This is the call that raises the Enzyme
# `augmented_primal` MethodError shown in the stack trace below.
Zygote.gradient(loss, 99.0)

The script errors on the `Zygote.gradient` call:

ERROR: MethodError: no method matching augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{…}, ::EnzymeCore.Const{…}, ::Type{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Const{…}, ::EnzymeCore.Const{…})

Closest candidates are:
  augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof(mul!)}, ::Type{RT}, ::EnzymeCore.Annotation{<:StridedVecOrMat}, ::EnzymeCore.Const{<:Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int64}}, I}} where I<:AbstractUnitRange} where {Tv, Ti}}, ::EnzymeCore.Annotation{<:StridedVecOrMat}, ::EnzymeCore.Annotation{<:Number}, ::EnzymeCore.Annotation{<:Number}) where RT
   @ Enzyme ~/.julia/packages/Enzyme/Vjlrr/src/internal_rules.jl:732
  augmented_primal(::Any, ::EnzymeCore.Const{typeof(QuadGK.quadgk)}, ::Type{RT}, ::Any, ::EnzymeCore.Annotation{T}...; kws...) where {RT, T}
   @ QuadGKEnzymeExt ~/.julia/packages/QuadGK/BjmU0/ext/QuadGKEnzymeExt.jl:6
  augmented_primal(::Any, ::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, ::Any, ::OutType, ::Any, ::Any, ::Any) where {OutType, RT}
   @ NNlibEnzymeCoreExt ~/.julia/packages/NNlib/CkJqS/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl:318
  ...

Stacktrace:
  [1] custom_rule_method_error
    @ ~/.julia/packages/Enzyme/Vjlrr/src/rules/customrules.jl:452 [inlined]
  [2] mul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [3] f
    @ ~/Documents/Work/dev/DecapodeCalibrateDemos/HeatEquation/Heat_alloc.jl:45
  [4] ODEFunction
    @ ~/.julia/packages/SciMLBase/tEuIM/src/scimlfunctions.jl:2336 [inlined]
  [5] #139
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/adjoint_common.jl:479 [inlined]
  [6] diffejulia__139_22024_inner_1wrap
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/adjoint_common.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/Vjlrr/src/compiler.jl:8839 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/Vjlrr/src/compiler.jl:8405 [inlined]
  [9] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/Vjlrr/src/compiler.jl:8178 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/Vjlrr/src/Enzyme.jl:491 [inlined]
 [11] _vecjacobian!(dλ::Vector{…}, y::ComponentVector{…}, λ::Vector{…}, p::ComponentVector{…}, t::Float64, S::SciMLSensitivity.ODEGaussAdjointSensitivityFunction{…}, isautojacvec::EnzymeVJP, dgrad::Nothing, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:710
 [12] #vecjacobian!#18
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:232 [inlined]
 [13] vecjacobian!
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:229 [inlined]
 [14] (::SciMLSensitivity.ODEGaussAdjointSensitivityFunction{…})(du::Vector{…}, u::Vector{…}, p::ComponentVector{…}, t::Float64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:102
 [15] ODEFunction
    @ ~/.julia/packages/SciMLBase/tEuIM/src/scimlfunctions.jl:2336 [inlined]
 [16] initialize!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, cache::OrdinaryDiffEqTsit5.Tsit5Cache{…})
    @ OrdinaryDiffEqTsit5 ~/.julia/packages/OrdinaryDiffEqTsit5/DHYtz/src/tsit_perform_step.jl:175
 [17] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:525
 [18] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:11 [inlined]
 [19] #__solve#75
    @ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:6 [inlined]
 [20] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:1 [inlined]
 [21] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:612
 [22] solve_call
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:569 [inlined]
 [23] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1092 [inlined]
 [24] solve_up
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1078 [inlined]
 [25] #solve#51
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1015 [inlined]
 [26] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:580
 [27] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:533 [inlined]
 [28] #adjoint_sensitivities#63
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/sensitivity_interface.jl:401 [inlined]
 [29] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#313"{})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:627
 [30] ZBack
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:212 [inlined]
 [31] (::Zygote.var"#294#295"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206
 [32] (::Zygote.var"#2169#back#296"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [33] #solve#51
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1015 [inlined]
 [34] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [35] #294
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
 [36] (::Zygote.var"#2169#back#296"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [37] solve
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1005 [inlined]
 [38] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [39] loss
    @ ~/Documents/Work/dev/DecapodeCalibrateDemos/HeatEquation/HeatEquation.jl:89 [inlined]
 [40] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [41] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91
 [42] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148
 [43] top-level scope
    @ ~/Documents/Work/dev/DecapodeCalibrateDemos/HeatEquation/HeatEquation.jl:94
Some type information was truncated. Use `show(err)` to see complete types.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant