MRP in RxInfer
using RxInfer, ReactiveMP, Random, Plots, StableRNGs
using LinearAlgebra, StatsPlots, LaTeXStrings, Combinatorics
#79 wrap
#uuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu
#116 wrap
#uuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu
partitions([1,2,3]) |> Iterators.flatten |> collect
10-element Vector{Vector{Int64}}:
[1, 2, 3]
[1, 2]
[3]
[1, 3]
[2]
[1]
[2, 3]
[1]
[2]
[3]
Generative process¶
function combine_pred_mat!(X)
A = zeros(J)
inds = LinearIndices(A) # inds[5,4,1] -> 1020
for j in CartesianIndices(J) # j = [1,1,1] ... [5,4,51]
X[inds[j], :] = [1, X¹[j[1]], X³[j[3],1], X³[j[3],2]]
end
end
combine_pred_mat! (generic function with 1 method)
rng = StableRNG(42)
ν = 1
σ₀ = 1
# income level X ethnicity X state
J = (5,4,51)
α = randn(rng, prod(J)) #cell specific random effect
β = [0.5, -0.05, 1., -.08e-5] #[const, inc level, prev rep vote, state inc]
X¹ = 1:5 # income levels
# X² = ones(4) # Note const terms added in with combine_pred_mat!
# Note: absolutely no interaction between income and rep vote share below
X³ = hcat(rand(rng, Uniform(0,1), 51), 12e3.*randn(rng, 51).+65e3) #rep prev vote share, wage yearly
X = ones(prod(J), 1 + 1 + 0 + 2) # const + (L₁ - 1) + (L₂ - 1) + (L₃ - 1) no interactions yet
X |> combine_pred_mat!
#~100k voters per cell to expect avg 100 mil potential voters
N = rand(rng, Uniform(40e3, 160e3),prod(J)) .|> round .|> Int
logits = (X * β) .+ α
θ = 1 ./ (1 .+ exp.(-logits))
function generate_synthetic_turnout_data(
J::Tuple{Int64, Int64, Int64},
θ::Vector{Float64}
)
cell_turnouts = [rand(rng, Binomial(N[j], θ[j])) for j in 1:prod(J)]
return cell_turnouts
end
y = generate_synthetic_turnout_data(J, θ);
InverseGamma(ν/2, ν * σ₀^2/2) # this isScale Inverse Chi Squared distribution as Gelman spec-d
InverseGamma{Float64}(
invd: Gamma{Float64}(α=0.5, θ=2.0)
θ: 0.5
)
Model spec¶
?BinomialPolya
search: BinomialPolya BinomialPolyaMeta Binomial binomial
BinomialPolya
A node type representing a Binomial likelihood with linear predictor through logistic. A Normal prior on the weights is used. The prior is augmented with a PolyaGamma distribution, which is used for modeling count data with overdispersion. This implementation follows the PolyaGamma augmentation scheme for Bayesian inference. Can be used for Binomial regression.
How many $\alpha$-s? $J = (5,4,51)$ thus without interactions there would be $\alpha^0$ (i.e 1) and then $5+4+51$ different random variables.
ci = CartesianIndices(zeros(J))
@model function binomial_model(y::Vector{<:Real})
print("hello")
β ~ MvNormalMeanCovariance(zeros(4), diageye(4)) #Gelman spec-d uniform, can't do that here
print("hello2")
σ⁰ ~ InverseGamma(ν/2, ν * σ₀^2/2) # this is Scale Inverse Chi Squared distribution as Gelman spec-d
print("hello3")
α⁰ ~ Normal(0,σ⁰)
print("hello4")
αⱼ₁[1:3] .~ Normal(0,σ⁰)
αⱼ₂[1:4] .~ Normal(0,σ⁰)
αⱼ₃[1:51] .~ Normal(0,σ⁰)
print("hello5")
for j in eachindex(y)
j₁,j₂,j₃ = ci[j] |> Tuple
Σα[j] ~ α⁰ + αⱼ₁[j₁] + αⱼ₂[j₂] + αⱼ₃[j₃]
# Currently trying to add to const term instead of TODO.
# βj := β + vcat(Σα[j],zeros(3)) #TODO: needs expanded β with zeros. Expanded X with another ones column.
y[j] ~ BinomialPolya(X[j], N[j], β) where {
dependencies = RequireMessageFunctionalDependencies(
β = MvNormalMeanCovariance(mean(β), cov(β))
)
}
end
end
┌ Warning: Type annotation found in interface y::Vector{<:Real}. While this will check that y is an Vector{<:Real}, dynamic creation of submodels using multiple dispatch is not supported. └ @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/model_macro.jl:684
ci = CartesianIndices(zeros(J))
@model function binomial_model(y::Vector{<:Real})
β ~ MvNormalMeanCovariance(zeros(4), diageye(4)) #Gelman spec-d uniform, can't do that here
print("hello2")
σ⁰ ~ InverseGamma(ν/2, ν * σ₀^2/2) # this is Scale Inverse Chi Squared distribution as Gelman spec-d
print("hello3")
α⁰ ~ Normal(0,σ⁰)
print("hello4")
αⱼ ~ Normal(0,σ⁰)
print("hello5")
for j in eachindex(y)
j₁,j₂,j₃ = ci[j] |> Tuple
Σα[j] ~ α⁰ + αⱼ
# Currently trying to add to const term instead of TODO.
# βj := β + vcat(Σα[j],zeros(3)) #TODO: needs expanded β with zeros. Expanded X with another ones column.
y[j] ~ BinomialPolya(X[j], N[j], β) where {
dependencies = RequireMessageFunctionalDependencies(
β = MvNormalMeanCovariance(zeros(4), diageye(4))
)
}
end
end
model_generator = binomial_model() | (y = y,)
model_to_plot = RxInfer.getmodel(RxInfer.create_model(model_generator))
GraphViz.load(model_to_plot, strategy = :simple, layout="dot", width=5.,height=5.)
┌ Warning: Type annotation found in interface y::Vector{<:Real}. While this will check that y is an Vector{<:Real}, dynamic creation of submodels using multiple dispatch is not supported. └ @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/model_macro.jl:684
MethodError: Cannot `convert` an object of type GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True} to an object of type Vector{<:Real} The function `convert` exists, but no method is defined for this combination of argument types. Closest candidates are: convert(::Type{T}, ::T) where T @ Base Base.jl:126 convert(::Type{T}, ::AbstractArray) where T<:Array @ Base array.jl:618 convert(::Type{T}, ::Factorization) where T<:AbstractArray @ LinearAlgebra ~/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/factorization.jl:104 ... Stacktrace: [1] add_terminated_submodel!(__model__::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, __context__::GraphPPL.Context, __options__::GraphPPL.NodeCreationOptions{@NamedTuple{created_by::GraphPPL.var"#95#96"}}, ::typeof(binomial_model), __interfaces__::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}, ::Static.StaticInt{1}) @ Main ~/.julia/packages/GraphPPL/xPNyo/src/model_macro.jl:700 [2] add_terminated_submodel!(model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, context::GraphPPL.Context, options::GraphPPL.NodeCreationOptions{@NamedTuple{created_by::GraphPPL.var"#95#96"}}, fform::Function, interfaces::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2118 [3] add_terminated_submodel!(model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, context::GraphPPL.Context, fform::Function, interfaces::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2114 [4] add_toplevel_model!(model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, context::GraphPPL.Context, fform::Function, interfaces::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2134 [5] create_model(callback::RxInfer.var"#53#55"{@NamedTuple{y::Vector{Int64}}}, generator::GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/model_generator.jl:159 [6] __infer_create_factor_graph_model(generator::GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, conditioned_on::@NamedTuple{y::Vector{Int64}}) @ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/model/model.jl:140 [7] create_model(generator::RxInfer.ConditionedModelGenerator{GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, @NamedTuple{y::Vector{Int64}}}) @ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/model/model.jl:128 [8] top-level scope @ In[24]:25
Inference¶
results = infer(
model = binomial_model(),
data = (y=y,),
iterations = 30,
free_energy = true,
showprogress = true,
options = (
limit_stack_depth = 100, # to prevent stack-overflow errors
)
)
MethodError: Cannot `convert` an object of type GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True} to an object of type Vector{<:Real} The function `convert` exists, but no method is defined for this combination of argument types. Closest candidates are: convert(::Type{T}, ::T) where T @ Base Base.jl:126 convert(::Type{T}, ::AbstractArray) where T<:Array @ Base array.jl:618 convert(::Type{T}, ::Factorization) where T<:AbstractArray @ LinearAlgebra ~/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/factorization.jl:104 ... Stacktrace: [1] add_terminated_submodel!(__model__::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, __context__::GraphPPL.Context, __options__::GraphPPL.NodeCreationOptions{@NamedTuple{created_by::GraphPPL.var"#95#96"}}, ::typeof(binomial_model), __interfaces__::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}, ::Static.StaticInt{1}) @ Main ~/.julia/packages/GraphPPL/xPNyo/src/model_macro.jl:700 [2] add_terminated_submodel!(model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, context::GraphPPL.Context, options::GraphPPL.NodeCreationOptions{@NamedTuple{created_by::GraphPPL.var"#95#96"}}, fform::Function, interfaces::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2118 [3] add_terminated_submodel!(model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, context::GraphPPL.Context, fform::Function, interfaces::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2114 [4] add_toplevel_model!(model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, context::GraphPPL.Context, fform::Function, interfaces::@NamedTuple{y::GraphPPL.ProxyLabel{GraphPPL.VariableRef{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, GraphPPL.Context, GraphPPL.NodeCreationOptions{@NamedTuple{kind::Symbol, factorized::Bool}}, Tuple{Nothing}, Vector{Int64}, Nothing}, Nothing, Static.True}}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2134 [5] create_model(callback::RxInfer.var"#53#55"{@NamedTuple{y::Vector{Int64}}}, generator::GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}) @ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/model_generator.jl:159 [6] __infer_create_factor_graph_model(generator::GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, conditioned_on::@NamedTuple{y::Vector{Int64}}) @ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/model/model.jl:140 [7] create_model(generator::RxInfer.ConditionedModelGenerator{GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, @NamedTuple{y::Vector{Int64}}}) @ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/model/model.jl:128 [8] batch_inference(; model::GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, data::@NamedTuple{y::Vector{Int64}}, initialization::Nothing, constraints::Nothing, meta::Nothing, options::@NamedTuple{limit_stack_depth::Int64}, returnvars::Nothing, predictvars::Nothing, iterations::Int64, free_energy::Bool, free_energy_diagnostics::Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, allow_node_contraction::Bool, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool) @ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/inference/batch.jl:203 [9] batch_inference @ ~/.julia/packages/RxInfer/dt0ny/src/inference/batch.jl:96 [inlined] [10] #288 @ ~/.julia/packages/RxInfer/dt0ny/src/inference/inference.jl:532 [inlined] [11] with_session(f::RxInfer.var"#288#290"{GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, @NamedTuple{y::Vector{Int64}}, Nothing, Nothing, Nothing, Nothing, Nothing, @NamedTuple{limit_stack_depth::Int64}, Nothing, Nothing, Nothing, Nothing, Int64, Bool, Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, Bool, Bool, Bool, Nothing, Nothing, DefaultPostprocess, Nothing, Bool, Bool, Bool}, session::RxInfer.Session, label::Symbol) @ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/session.jl:253 [12] #infer#287 @ ~/.julia/packages/RxInfer/dt0ny/src/inference/inference.jl:502 [inlined] [13] top-level scope @ In[19]:1
Tutorial code¶
?cov
search: cov cor cot mov cos coth chop cotd acos cond cosd conj copy invcov cosh
cov(x::AbstractVector; corrected::Bool=true)
Compute the variance of the vector x. If corrected is true (the default) then the sum is scaled with n-1, whereas the sum is scaled with n if corrected is false where n = length(x).
cov(X::AbstractMatrix; dims::Int=1, corrected::Bool=true)
Compute the covariance matrix of the matrix X along the dimension dims. If corrected is true (the default) then the sum is scaled with n-1, whereas the sum is scaled with n if corrected is false where n = size(X, dims).
cov(x::AbstractVector, y::AbstractVector; corrected::Bool=true)
Compute the covariance between the vectors x and y. If corrected is true (the default), computes $\frac{1}{n-1}\sum_{i=1}^n (x_i-\bar x) (y_i-\bar y)^*$ where $*$ denotes the complex conjugate and n = length(x) = length(y). If corrected is false, computes $\frac{1}{n}\sum_{i=1}^n (x_i-\bar x) (y_i-\bar y)^*$.
cov(X::AbstractVecOrMat, Y::AbstractVecOrMat; dims::Int=1, corrected::Bool=true)
Compute the covariance between the vectors or matrices X and Y along the dimension dims. If corrected is true (the default) then the sum is scaled with n-1, whereas the sum is scaled with n if corrected is false where n = size(X, dims) = size(Y, dims).
cov(X, w::AbstractWeights, vardim=1; mean=nothing, corrected=false)
Compute the weighted covariance matrix. Similar to var and std the biased covariance matrix (corrected=false) is computed by multiplying scattermat(X, w) by $\frac{1}{\sum{w}}$ to normalize. However, the unbiased covariance matrix (corrected=true) is dependent on the type of weights used:
AnalyticWeights: $\frac{1}{\sum w - \sum {w^2} / \sum w}$FrequencyWeights: $\frac{1}{\sum{w} - 1}$ProbabilityWeights: $\frac{n}{(n - 1) \sum w}$ where $n$ equalscount(!iszero, w)Weights:ArgumentError(bias correction not supported)
cov(ce::CovarianceEstimator, x::AbstractVector; mean=nothing)
Compute a variance estimate from the observation vector x using the estimator ce.
cov(ce::CovarianceEstimator, x::AbstractVector, y::AbstractVector)
Compute the covariance of the vectors x and y using estimator ce.
cov(ce::CovarianceEstimator, X::AbstractMatrix, [w::AbstractWeights]; mean=nothing, dims::Int=1)
Compute the covariance matrix of the matrix X along dimension dims using estimator ce. A weighting vector w can be specified. The keyword argument mean can be:
nothing(default) in which case the mean is estimated and subtracted from the dataX,a precalculated mean in which case it is subtracted from the data
X. Assumingsize(X)is(N,M),meancan either be:- when
dims=1, anAbstractMatrixof size(1,M), - when
dims=2, anAbstractVectorof lengthNor anAbstractMatrixof size(N,1).
- when
cov(d::MultivariateDistribution)
Compute the covariance matrix for distribution d. (cor is provided based on cov).
cov(d::MatrixDistribution)
Compute the covariance matrix for vec(X), where X is a random matrix with distribution d.
cov(d::MatrixDistribution, flattened = Val(false))
Compute the 4-dimensional array whose (i, j, k, l) element is cov(X[i,j], X[k, l]).
cov(M::PPCA)
Returns the covariance of the model M.
cov(M::FactorAnalysis)
Returns the covariance of the model M.
X = MvNormalWeightedMeanPrecision(zeros(4), diageye(4))
params(X)
([0.0, 0.0, 0.0, 0.0], [1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.0 1.0 0.0; 0.0 0.0 0.0 1.0])
β |> typeof
GraphPPL.ModelGenerator{typeof(binomial_model), Base.Pairs{Symbol, Array, NTuple{5, Symbol}, @NamedTuple{prior_xi::Vector{Float64}, prior_precision::Matrix{Float64}, X::Vector{Vector{Float64}}, y::Vector{Int64}, n_trials::Vector{Int64}}}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}
?RequireMessageFunctionalDependencies
search: RequireMessageFunctionalDependencies
RequireMessageFunctionalDependencies(specifications::NamedTuple)
RequireMessageFunctionalDependencies(; specifications...)
The same as DefaultFunctionalDependencies, but in order to compute a message out of some edge also requires the inbound message on the this edge.
The specification parameter is a named tuple that contains the names of the edges and their initial messages. When a name is present in the named tuple, that indicates that the computation of the outbound message on the same edge must use the inbound message. If nothing is passed as a value in the named tuple, the initial message is not set. Note that the construction allows passing keyword argument to the constructor instead of using NamedTuple directly.
RequireMessageFunctionalDependencies(μ = vague(NormalMeanPrecision), τ = nothing)
# ^^^ ^^^
# request 'inbound' for 'x' we may do the same for 'τ',
# and initialise with `vague(...)` but here we skip initialisation
See also: ReactiveMP.DefaultFunctionalDependencies, ReactiveMP.RequireMarginalFunctionalDependencies, ReactiveMP.RequireEverythingFunctionalDependencies
function generate_synthetic_binomial_data(
n_samples::Int,
true_beta::Vector{Float64};
seed::Int=42
)
n_features = length(true_beta)
rng = StableRNG(seed)
X = randn(rng, n_samples, n_features)
n_trials = rand(rng, 5:20, n_samples)
logits = X * true_beta
prob2 = 1 ./ (1 .+ exp.(-logits))
y = [rand(rng, Binomial(n_trials[i], prob2[i])) for i in 1:n_samples]
return X, y, n_trials, prob2
end
n_samples = 10000
true_beta = [-3.0 , 2.6]
X, y, n_trials,prob2 = generate_synthetic_binomial_data(n_samples, true_beta);
X = [collect(row) for row in eachrow(X)];
@model function binomial_model(prior_xi, prior_precision, n_trials, X, y)
println("hello")
β ~ MvNormalWeightedMeanPrecision(prior_xi, prior_precision)
for i in eachindex(y)
y[i] ~ BinomialPolya(X[i], n_trials[i], β) where {
dependencies = RequireMessageFunctionalDependencies(β = MvNormalWeightedMeanPrecision(prior_xi, prior_precision))
}
end
println(β[:weightedmean_precision] |> typeof)
return β
end
β = binomial_model(prior_xi = zeros(n_features), prior_precision = diageye(n_features),X=X,y=y,n_trials=n_trials);
n_features = length(true_beta)
results = infer(
model = binomial_model(prior_xi = zeros(n_features), prior_precision = diageye(n_features),),
data = (X=X, y=y,n_trials=n_trials),
iterations = 30,
free_energy = true,
showprogress = true,
options = (
limit_stack_depth = 100, # to prevent stack-overflow errors
)
)
hello GraphPPL.NodeLabel
RuleMethodError: no method matching rule for the given arguments
Possible fix, define:
@rule BinomialPolya(:β, Marginalisation) (q_y::PointMass, q_x::PointMass, q_n::PointMass, ) = begin
return ...
end
Alternatively, consider re-specifying model using an existing rule:
BinomialPolya(m_β::Union{NormalDistributionsFamily{T}, GaussianDistributionsFamily{T}} where T)
BinomialPolya(q_x::PointMass, q_n::PointMass, q_β::Union{NormalDistributionsFamily{T}, GaussianDistributionsFamily{T}} where T)
BinomialPolya(q_y::Union{PointMass, Multinomial}, q_x::PointMass, q_n::PointMass)
Note that for marginal rules (i.e., involving q_*), the order of input types matters.
Stacktrace:
[1] (::ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing})(messages::Nothing, marginals::Tuple{Marginal{PointMass{Int64}, Nothing}, Marginal{PointMass{Vector{Float64}}, Nothing}, Marginal{PointMass{Int64}, Nothing}})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/message.jl:357
[2] as_message(message::DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}}, cache::Nothing, messages::Nothing, marginals::Tuple{Marginal{PointMass{Int64}, Nothing}, Marginal{PointMass{Vector{Float64}}, Nothing}, Marginal{PointMass{Int64}, Nothing}})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/message.jl:235
[3] as_message(message::DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}}, cache::Nothing)
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/message.jl:231
[4] as_message(message::DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/message.jl:223
[5] MappingRF
@ ./reduce.jl:100 [inlined]
[6] _foldl_impl(op::Base.MappingRF{typeof(as_message), Base.BottomRF{ReactiveMP.var"#16#18"{GenericProd}}}, init::Base._InitialValue, itr::Vector{AbstractMessage})
@ Base ./reduce.jl:62
[7] foldl_impl
@ ./reduce.jl:48 [inlined]
[8] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[9] mapfoldl
@ ./reduce.jl:175 [inlined]
[10] foldl
@ ./reduce.jl:198 [inlined]
[11] (::ReactiveMP.var"#15#17"{GenericProd, CompositeFormConstraint{Tuple{UnspecifiedFormConstraint, RxInfer.EnsureSupportedFunctionalForm}}})(messages::Vector{AbstractMessage})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/message.jl:149
[12] (::ReactiveMP.var"#119#120"{ReactiveMP.var"#15#17"{GenericProd, CompositeFormConstraint{Tuple{UnspecifiedFormConstraint, RxInfer.EnsureSupportedFunctionalForm}}}})(messages::Vector{AbstractMessage})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/variables/variable.jl:36
[13] next_received!(wrapper::Rocket.CollectLatestObservableWrapper{AbstractMessage, Rocket.RecentSubjectInstance{Marginal, Subject{Marginal, AsapScheduler, AsapScheduler}}, Vector{AbstractMessage}, BitVector, Vector{Teardown}, ReactiveMP.var"#119#120"{ReactiveMP.var"#15#17"{GenericProd, CompositeFormConstraint{Tuple{UnspecifiedFormConstraint, RxInfer.EnsureSupportedFunctionalForm}}}}, typeof(ReactiveMP.reset_vstatus)}, data::DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}}, index::CartesianIndex{1})
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/observable/collected.jl:103
[14] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/observable/collected.jl:93 [inlined]
[15] scheduled_next!(actor::Rocket.CollectLatestObservableInnerActor{AbstractMessage, CartesianIndex{1}, Rocket.CollectLatestObservableWrapper{AbstractMessage, Rocket.RecentSubjectInstance{Marginal, Subject{Marginal, AsapScheduler, AsapScheduler}}, Vector{AbstractMessage}, BitVector, Vector{Teardown}, ReactiveMP.var"#119#120"{ReactiveMP.var"#15#17"{GenericProd, CompositeFormConstraint{Tuple{UnspecifiedFormConstraint, RxInfer.EnsureSupportedFunctionalForm}}}}, typeof(ReactiveMP.reset_vstatus)}}, value::DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}}, ::AsapScheduler)
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/schedulers/asap.jl:23
[16] on_next!(subject::Subject{AbstractMessage, AsapScheduler, AsapScheduler}, data::DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}})
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/subjects/subject.jl:62
[17] actor_on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:250 [inlined]
[18] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:202 [inlined]
[19] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/subjects/recent.jl:62 [inlined]
[20] #7
@ ~/.julia/packages/RxInfer/dt0ny/src/rocket.jl:86 [inlined]
[21] limitstack(callback::RxInfer.var"#7#8"{Rocket.RecentSubjectInstance{AbstractMessage, Subject{AbstractMessage, AsapScheduler, AsapScheduler}}, DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}}}, instance::RxInfer.LimitStackScheduler)
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/rocket.jl:52
[22] scheduled_next!
@ ~/.julia/packages/RxInfer/dt0ny/src/rocket.jl:86 [inlined]
[23] actor_on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:255 [inlined]
[24] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:203 [inlined]
[25] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/observable/scheduled.jl:12 [inlined]
[26] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:206 [inlined]
[27] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/operators/map.jl:62 [inlined]
[28] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:206 [inlined]
[29] next_received!
@ ~/.julia/packages/Rocket/Qsjhz/src/observable/combined.jl:101 [inlined]
[30] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/observable/combined.jl:68 [inlined]
[31] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:206 [inlined]
[32] next_received!
@ ~/.julia/packages/Rocket/Qsjhz/src/observable/combined_updates.jl:72 [inlined]
[33] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/observable/combined_updates.jl:34 [inlined]
[34] scheduled_next!(actor::Rocket.CombineLatestUpdatesInnerActor{Marginal, Rocket.CombineLatestUpdatesActorWrapper{Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, Rocket.CombineLatestInnerActor{Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, Rocket.CombineLatestActorWrapper{Rocket.MStorage2{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}}, Rocket.MapActor{Tuple{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}}, Rocket.ScheduledActor{AbstractMessage, RxInfer.LimitStackScheduler, Rocket.RecentSubjectInstance{AbstractMessage, Subject{AbstractMessage, AsapScheduler, AsapScheduler}}}, ReactiveMP.var"#182#185"{ReactiveMP.MessageMapping{BinomialPolya, Val{:β}, Marginalisation, Nothing, Val{(:y, :x, :n)}, Nothing, Nothing, Nothing, Nothing}}}, PushNew, Rocket.UInt8UpdatesStatus}}, PushNew, Rocket.UInt8UpdatesStatus, typeof(identity), Nothing}}, value::Marginal{PointMass{Int64}, Nothing}, ::AsapScheduler)
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/schedulers/asap.jl:23
[35] on_next!(subject::Subject{Marginal, AsapScheduler, AsapScheduler}, data::Marginal{PointMass{Int64}, Nothing})
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/subjects/subject.jl:62
[36] actor_on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:250 [inlined]
[37] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:202 [inlined]
[38] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/subjects/recent.jl:62 [inlined]
[39] actor_on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:250 [inlined]
[40] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:202 [inlined]
[41] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/operators/map.jl:62 [inlined]
[42] scheduled_next!(actor::Rocket.MapActor{Message, Rocket.RecentSubjectInstance{Marginal, Subject{Marginal, AsapScheduler, AsapScheduler}}, typeof(as_marginal)}, value::Message{PointMass{Int64}, Nothing}, ::AsapScheduler)
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/schedulers/asap.jl:23
[43] on_next!(subject::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{PointMass{Int64}, Nothing})
@ Rocket ~/.julia/packages/Rocket/Qsjhz/src/subjects/subject.jl:62
[44] actor_on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:250 [inlined]
[45] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:202 [inlined]
[46] on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/subjects/recent.jl:62 [inlined]
[47] actor_on_next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:250 [inlined]
[48] next!
@ ~/.julia/packages/Rocket/Qsjhz/src/actor.jl:202 [inlined]
[49] update!
@ ~/.julia/packages/ReactiveMP/KHgtf/src/variables/data.jl:85 [inlined]
[50] update!
@ ~/.julia/packages/ReactiveMP/KHgtf/src/variables/data.jl:84 [inlined]
[51] #137
@ ~/.julia/packages/ReactiveMP/KHgtf/src/variables/data.jl:93 [inlined]
[52] foreach
@ ./abstractarray.jl:3187 [inlined]
[53] update!(datavars::Vector{DataVariable{Rocket.RecentSubjectInstance{Message, Subject{Message, AsapScheduler, AsapScheduler}}, ReactiveMP.MarginalObservable}}, data::Vector{Int64})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/variables/data.jl:92
[54] batch_inference(; model::GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{prior_xi::Vector{Float64}, prior_precision::Matrix{Float64}}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, data::@NamedTuple{X::Vector{Vector{Float64}}, y::Vector{Int64}, n_trials::Vector{Int64}}, initialization::Nothing, constraints::Nothing, meta::Nothing, options::@NamedTuple{limit_stack_depth::Int64}, returnvars::Nothing, predictvars::Nothing, iterations::Int64, free_energy::Bool, free_energy_diagnostics::Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, allow_node_contraction::Bool, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool)
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/inference/batch.jl:304
[55] batch_inference
@ ~/.julia/packages/RxInfer/dt0ny/src/inference/batch.jl:96 [inlined]
[56] #288
@ ~/.julia/packages/RxInfer/dt0ny/src/inference/inference.jl:532 [inlined]
[57] with_session(f::RxInfer.var"#288#290"{GraphPPL.ModelGenerator{typeof(binomial_model), @Kwargs{prior_xi::Vector{Float64}, prior_precision::Matrix{Float64}}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, @NamedTuple{X::Vector{Vector{Float64}}, y::Vector{Int64}, n_trials::Vector{Int64}}, Nothing, Nothing, Nothing, Nothing, Nothing, @NamedTuple{limit_stack_depth::Int64}, Nothing, Nothing, Nothing, Nothing, Int64, Bool, Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, Bool, Bool, Bool, Nothing, Nothing, DefaultPostprocess, Nothing, Bool, Bool, Bool}, session::RxInfer.Session, label::Symbol)
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/session.jl:253
[58] #infer#287
@ ~/.julia/packages/RxInfer/dt0ny/src/inference/inference.jl:502 [inlined]
?GraphPPL.NodeLabel
!!! warning The following bindings may be internal; they may change or be removed in future versions:
* `GraphPPL.NodeLabel`
NodeLabel(name, global_counter::Int64)
Unique identifier for a node (factor or variable) in a probabilistic graphical model.
β[:weightedmean_precision]
MethodError: no method matching getindex(::GraphPPL.ModelGenerator{typeof(binomial_model), Base.Pairs{Symbol, Array, NTuple{5, Symbol}, @NamedTuple{prior_xi::Vector{Float64}, prior_precision::Matrix{Float64}, X::Vector{Vector{Float64}}, y::Vector{Int64}, n_trials::Vector{Int64}}}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, ::Symbol)
The function `getindex` exists, but no method is defined for this combination of argument types.
Stacktrace:
[1] top-level scope
@ In[102]:1
names(RxInfer)
847-element Vector{Symbol}:
Symbol("@autoupdates")
Symbol("@average_energy")
Symbol("@call_marginalrule")
Symbol("@call_rule")
Symbol("@constraints")
Symbol("@initialization")
Symbol("@logscale")
Symbol("@marginalrule")
Symbol("@meta")
Symbol("@model")
Symbol("@node")
Symbol("@rule")
:AND
⋮
:weightedmean
:weightedmean_cov
:weightedmean_invcov
:weightedmean_precision
:weightedmean_std
:weightedmean_var
:weights
:with_latest
:wsample
:wsample!
:xtlog
:zipped
?GraphPPL.@model
!!! warning The following bindings may be internal; they may change or be removed in future versions:
* `GraphPPL.@model`
@model function model_name(model_arguments)
...
end
Note that the @model macro is not exported by default and the recommended way of using it is in the combination with some inference backend. The GraphPPL package provides the DefaultGraphPPLBackend structure for plotting and test purposes, but some backends may specify different behaviour for different structures. For example, the interface names of a node Normal or its behaviour may (and should) depend on the specified backend.
The recommended way of using the GraphPPL.@model macro from other backend-based packages is to define their own @model macro, which will call the GraphPPL.model_macro_interior function with the specified backend. For example
module SamplingBasedInference
struct SamplingBasedBackend end
macro model(model_specification)
return esc(GraphPPL.model_macro_interior(SamplingBasedBackend(), model_specification))
end
end
Read more about the backend inteface in the corresponding section of the documentation.
To use GraphPPL package as a standalone package for plotting and testing, use the import GraphPPL: @model explicitly to add the @model macro to the current scope.
plot(results.free_energy,fontfamily = "Computer Modern", label="Free Energy", xlabel="Iteration", ylabel="Free Energy", title="Free Energy Convergence")
ci = CartesianIndices(zeros(J))
@model function binomial_model(y::Vector{<:Real})
β ~ MvNormalMeanCovariance(zeros(4), diageye(4)) #Gelman spec-d uniform, can't do that here
print("hello2")
σ⁰ ~ InverseGamma(ν/2, ν * σ₀^2/2) # this is Scale Inverse Chi Squared distribution as Gelman spec-d
print("hello3")
α⁰ ~ Normal(0,σ⁰)
print("hello4")
αⱼ ~ Normal(0,σ⁰)
print("hello5")
for j in eachindex(y)
j₁,j₂,j₃ = ci[j] |> Tuple
Σα[j] ~ α⁰ + αⱼ
# Currently trying to add to const term instead of TODO.
# βj := β + vcat(Σα[j],zeros(3)) #TODO: needs expanded β with zeros. Expanded X with another ones column.
y[j] ~ BinomialPolya(X[j], N[j], β) where {
dependencies = RequireMessageFunctionalDependencies(
β = MvNormalMeanCovariance(zeros(4), diageye(4))
)
}
end
end
model_generator = binomial_model() | (y = y,)
model_to_plot = RxInfer.getmodel(RxInfer.create_model(model_generator))
GraphViz.load(model_to_plot, strategy = :simple, layout="dot", width=5.,height=5.)
4
Truncated normal?¶
?DeltaMeta
search: DeltaMeta FlowMeta DeltaFn
DeltaMeta(method = ..., [ inverse = ... ])
DeltaMeta structure specifies the approximation method for the outbound messages in the DeltaFn node.
Arguments¶
method: required, the approximation method, currently supported methods areLinearization,UnscentedandCVI.inverse: optional, if no inverse provided, the backward rule will be computed based on RTS (Petersen et al. 2018; On Approximate Delta Gaussian Message Passing on Factor Graphs)
Is is also possible to pass the AbstractApproximationMethod to the meta of the delta node directly. In this case inverse is set to nothing.
using RxInfer
using Distributions
# Define a custom model with truncated normal using an indicator function
@model function truncated_normal_model(y)
# Latent variable x with Normal prior
y ~ truncated(Normal(mean = 0, var = 1), -1, 1)
end
# Example observed data
y_data = 0.5
# Instantiate and run inference
model = truncated_normal_model(;y_data)
results = infer(
model = truncated_normal_model(),
data = (y=y,),
iterations = 30,
free_energy = true,
showprogress = true,
options = (
limit_stack_depth = 100, # to prevent stack-overflow errors
)
)
println("Posterior of x with truncation constraint:")
println(result.posteriors[:x])
Delta node `truncated` requires meta specification with the `where { meta = ... }` in the `@model` macro or with the separate `@meta` specification. See documentation for the `DeltaMeta`.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] collect_meta(::Type{DeltaFn{typeof(truncated)}}, something::Nothing)
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/nodes/predefined/delta/delta.jl:55
[3] activate!(factornode::DeltaFnNode{typeof(truncated), FixedArguments.FixedCallable{typeof(truncated), Tuple{FixedArguments.FixedArgument{FixedArguments.FixedPosition{2}, ConstVariable}, FixedArguments.FixedArgument{FixedArguments.FixedPosition{3}, ConstVariable}}, typeof(ReactiveMP.__unpack_latest_static)}, 1, Tuple{FixedArguments.FixedArgument{FixedArguments.FixedPosition{2}, ConstVariable}, FixedArguments.FixedArgument{FixedArguments.FixedPosition{3}, ConstVariable}}, ReactiveMP.FactorNodeLocalClusters{Tuple{ReactiveMP.FactorNodeLocalMarginal, ReactiveMP.FactorNodeLocalMarginal}, Nothing}}, options::ReactiveMP.FactorNodeActivationOptions{Nothing, Nothing, Nothing, Nothing, RxInfer.LimitStackScheduler, Nothing})
@ ReactiveMP ~/.julia/packages/ReactiveMP/KHgtf/src/nodes/predefined/delta/delta.jl:184
[4] activate_rmp_factornode!(plugin::RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, nodedata::GraphPPL.NodeData, nodeproperties::GraphPPL.FactorNodeProperties{GraphPPL.NodeData})
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/model/plugins/reactivemp_inference.jl:232
[5] #64
@ ~/.julia/packages/RxInfer/dt0ny/src/model/plugins/reactivemp_inference.jl:144 [inlined]
[6] factor_nodes(callback::RxInfer.var"#64#68"{RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}}, model::GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String})
@ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:865
[7] postprocess_plugin
@ ~/.julia/packages/RxInfer/dt0ny/src/model/plugins/reactivemp_inference.jl:143 [inlined]
[8] (::GraphPPL.var"#97#98"{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}})(plugin::RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}})
@ GraphPPL ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2136
[9] foreach(f::GraphPPL.var"#97#98"{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}}, itr::GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}})
@ Base ./abstractarray.jl:3187
[10] add_toplevel_model!
@ ~/.julia/packages/GraphPPL/xPNyo/src/graph_engine.jl:2135 [inlined]
[11] create_model
@ ~/.julia/packages/GraphPPL/xPNyo/src/model_generator.jl:159 [inlined]
[12] __infer_create_factor_graph_model
@ ~/.julia/packages/RxInfer/dt0ny/src/model/model.jl:140 [inlined]
[13] create_model(generator::RxInfer.ConditionedModelGenerator{GraphPPL.ModelGenerator{typeof(truncated_normal_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{RxInfer.LimitStackScheduler, Nothing, Nothing}}, RxInfer.ReactiveMPFreeEnergyPlugin{RxInfer.BetheFreeEnergy{Real, SkipInitial, AsapScheduler}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, @NamedTuple{y::Vector{Int64}}})
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/model/model.jl:128
[14] batch_inference(; model::GraphPPL.ModelGenerator{typeof(truncated_normal_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, data::@NamedTuple{y::Vector{Int64}}, initialization::Nothing, constraints::Nothing, meta::Nothing, options::@NamedTuple{limit_stack_depth::Int64}, returnvars::Nothing, predictvars::Nothing, iterations::Int64, free_energy::Bool, free_energy_diagnostics::Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, allow_node_contraction::Bool, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool)
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/inference/batch.jl:203
[15] batch_inference
@ ~/.julia/packages/RxInfer/dt0ny/src/inference/batch.jl:96 [inlined]
[16] #288
@ ~/.julia/packages/RxInfer/dt0ny/src/inference/inference.jl:532 [inlined]
[17] with_session(f::RxInfer.var"#288#290"{GraphPPL.ModelGenerator{typeof(truncated_normal_model), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}, @NamedTuple{y::Vector{Int64}}, Nothing, Nothing, Nothing, Nothing, Nothing, @NamedTuple{limit_stack_depth::Int64}, Nothing, Nothing, Nothing, Nothing, Int64, Bool, Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, Bool, Bool, Bool, Nothing, Nothing, DefaultPostprocess, Nothing, Bool, Bool, Bool}, session::RxInfer.Session, label::Symbol)
@ RxInfer ~/.julia/packages/RxInfer/dt0ny/src/session.jl:253
[18] #infer#287
@ ~/.julia/packages/RxInfer/dt0ny/src/inference/inference.jl:502 [inlined]
[19] top-level scope
@ In[124]:16
using Distributions
d = truncated(Normal(0, 1), -1, 1)
Truncated(Normal{Float64}(μ=0.0, σ=1.0); lower=-1.0, upper=1.0)
li = LinearIndices(A)
lin_index = li[cart_index]
A[lin_index]
0.21330265646780877