Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pointpriors #663

Merged
merged 39 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
d05124c
implement pointwise_logpriors
bgctw Sep 16, 2024
4f46102
implement varwise_logpriors
bgctw Sep 17, 2024
c6653b9
remove pointwise_logpriors
bgctw Sep 17, 2024
216d50c
revert dot_assume to not explicitly resolve components of sum
bgctw Sep 17, 2024
fd8d3b2
docstring varwise_logpriores
bgctw Sep 18, 2024
5842656
integrate pointwise_loglikelihoods and varwise_logpriors by pointwise…
bgctw Sep 19, 2024
18beb57
record single prior components
bgctw Sep 21, 2024
d9945d7
forward dot_tilde_assume to tilde_assume for Multivariate
bgctw Sep 22, 2024
656a757
avoid recording prior components on leaf-prior-context
bgctw Sep 24, 2024
7aa9ebe
undeprecate pointwise_loglikelihoods and implement pointwise_prior_lo…
bgctw Sep 24, 2024
2f67c5b
drop vi instead of re-compute vi
bgctw Sep 24, 2024
9dfb9ed
include docstrings of pointwise_logdensities
bgctw Sep 24, 2024
c1939e0
Update src/pointwise_logdensities.jl remove commented code
bgctw Sep 25, 2024
790be1d
Update src/pointwise_logdensities.jl remove commented code
bgctw Sep 25, 2024
426df38
Update test/pointwise_logdensities.jl rename m to model
bgctw Sep 25, 2024
c32bf3b
Update test/pointwise_logdensities.jl remove unused code
bgctw Sep 25, 2024
6213249
Update test/pointwise_logdensities.jl rename m to model
bgctw Sep 25, 2024
3551b38
Update test/pointwise_logdensities.jl rename m to model
bgctw Sep 25, 2024
95c892b
Update src/test_utils.jl remove old code
bgctw Sep 25, 2024
a7a7e70
rename m to model
bgctw Sep 25, 2024
1653aba
JuliaFormatter
bgctw Sep 25, 2024
e4f0a1d
Merge branch 'pointpriors' of github.com:bgctw/DynamicPPL.jl into poi…
bgctw Sep 25, 2024
a99eab4
Update test/runtests.jl remove interactive code
bgctw Sep 26, 2024
64ce63a
remove demo_dot_assume_matrix_dot_observe_matrix2 testcase
bgctw Sep 26, 2024
456115c
ignore local interactive development code
bgctw Sep 26, 2024
222529a
ignore temporary directory holding local interactive development code
bgctw Sep 26, 2024
17b251a
Apply suggestions from code review: clean up comments and Imports
bgctw Sep 26, 2024
7e990f0
Apply suggestions from code review: change test of applying to chains…
bgctw Sep 26, 2024
8706f68
fix test on names in likelihood components
bgctw Sep 26, 2024
073a325
try to fix testset pointwise_logdensities chain
bgctw Sep 26, 2024
23e1711
Update test/pointwise_logdensities.jl
torfjelde Sep 26, 2024
34ae4f8
Update .gitignore
torfjelde Sep 26, 2024
1f251d1
Merge branch 'master' into pointpriors
torfjelde Sep 26, 2024
777624a
Formtating
torfjelde Sep 26, 2024
4864e60
Fixed tests
torfjelde Sep 26, 2024
4d3b0c0
Updated docs for `pointwise_logdensities` + made it a doctest not
torfjelde Sep 27, 2024
e54fa4e
Bump patch version
torfjelde Sep 27, 2024
bcd82a9
Remove blank line from `@model` in doctest to see if that fixes the
torfjelde Sep 27, 2024
cff0941
Added doctest filter to handle the `;;]` at the end of lines for matr…
torfjelde Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob
pointwise_loglikelihoods
```

Similarly, one can compute the pointwise log-priors of each sampled random variable
with [`varwise_logpriors`](@ref).
Differently from `pointwise_loglikelihoods` it reports only a
single value for `.~` assignements.
If one needs to access the parts for single indices, one can
reformulate the model to use an explicit loop instead.

```@docs
varwise_logpriors
```

For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref).

```@docs
Expand Down
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export AbstractVarInfo,
logprior,
logjoint,
pointwise_loglikelihoods,
pointwise_logdensities,
condition,
decondition,
fix,
Expand Down Expand Up @@ -181,14 +182,15 @@ include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("loglikelihoods.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("deprecated.jl")

include("debug_utils.jl")
using .DebugUtils
Expand Down
9 changes: 9 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# https://invenia.github.io/blog/2022/06/17/deprecating-in-julia/

Base.@deprecate pointwise_loglikelihoods(model::Model, chain, keytype) pointwise_logdensities(
model::Model, LikelihoodContext(), chain, keytype)

Base.@deprecate pointwise_loglikelihoods(
model::Model, varinfo::AbstractVarInfo) pointwise_logdensities(
model::Model, varinfo, LikelihoodContext())
bgctw marked this conversation as resolved.
Show resolved Hide resolved

128 changes: 87 additions & 41 deletions src/loglikelihoods.jl → src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
@@ -1,83 +1,83 @@
# Context version
struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext
loglikelihoods::A
struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext
logdensities::A
context::Ctx
end

function PointwiseLikelihoodContext(
function PointwiseLogdensityContext(
likelihoods=OrderedDict{VarName,Vector{Float64}}(),
context::AbstractContext=LikelihoodContext(),
context::AbstractContext=DefaultContext(),
)
return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}(
return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}(
likelihoods, context
)
end

NodeTrait(::PointwiseLikelihoodContext) = IsParent()
childcontext(context::PointwiseLikelihoodContext) = context.context
function setchildcontext(context::PointwiseLikelihoodContext, child)
return PointwiseLikelihoodContext(context.loglikelihoods, child)
NodeTrait(::PointwiseLogdensityContext) = IsParent()
childcontext(context::PointwiseLogdensityContext) = context.context
function setchildcontext(context::PointwiseLogdensityContext, child)
return PointwiseLogdensityContext(context.logdensities, child)
end

function Base.push!(
context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}},
context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}},
vn::VarName,
logp::Real,
)
lookup = context.loglikelihoods
lookup = context.logdensities
ℓ = get!(lookup, vn, Float64[])
return push!(ℓ, logp)
end

function Base.push!(
context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}},
context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}},
vn::VarName,
logp::Real,
)
return context.loglikelihoods[vn] = logp
return context.logdensities[vn] = logp
end

function Base.push!(
context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}},
context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}},
vn::VarName,
logp::Real,
)
lookup = context.loglikelihoods
lookup = context.logdensities
ℓ = get!(lookup, string(vn), Float64[])
return push!(ℓ, logp)
end

function Base.push!(
context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}},
context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}},
vn::VarName,
logp::Real,
)
return context.loglikelihoods[string(vn)] = logp
return context.logdensities[string(vn)] = logp
end

function Base.push!(
context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}},
context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}},
vn::String,
logp::Real,
)
lookup = context.loglikelihoods
lookup = context.logdensities
ℓ = get!(lookup, vn, Float64[])
return push!(ℓ, logp)
end

function Base.push!(
context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}},
context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}},
vn::String,
logp::Real,
)
return context.loglikelihoods[vn] = logp
return context.logdensities[vn] = logp
end

function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi)
function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
# Defer literal `observe` to child-context.
return tilde_observe!!(context.context, right, left, vi)
end
function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi)
function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
# we have to intercept the call to `tilde_observe!`.
logp, vi = tilde_observe(context.context, right, left, vi)
Expand All @@ -88,11 +88,11 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, v
return left, acclogp!!(vi, logp)
end

function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi)
function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
# Defer literal `observe` to child-context.
return dot_tilde_observe!!(context.context, right, left, vi)
end
function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi)
function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
# we have to intercept the call to `dot_tilde_observe!`.

Expand Down Expand Up @@ -129,8 +129,49 @@ function _pointwise_tilde_observe(
end
end

function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi)
#@info "PointwiseLogdensityContext tilde_assume!! called for $vn"
value, logp, vi = tilde_assume(context.context, right, vn, vi)
#sym = DynamicPPL.getsym(vn)
new_context = acc_logp!(context, vn, logp)
return value, logp, vi
end

function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi)
#@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn"
# @show vn, left, right, typeof(context).name
value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi)
new_context = acc_logp!(context, vn, logp)
return value, logp, vi
end

function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp)
push!(context, vn, logp)
return (context)
end

function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp)
# construct a new VarName from given sequence of VarName
# assume that all items in vns have an IndexLens optic
indices = tuplejoin(map(vn -> getoptic(vn).indices, vns)...)
vn = VarName(first(vns), Accessors.IndexLens(indices))
push!(context, vn, logp)
return (context)
end

#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8
@inline tuplejoin(x) = x
@inline tuplejoin(x, y) = (x..., y...)
@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...)

() -> begin
# code that generates julia-repl in docstring below
# using DynamicPPL, Turing
# TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available
end
bgctw marked this conversation as resolved.
Show resolved Hide resolved

"""
pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
pointwise_logdensities(model::Model, chain::Chains, keytype = String)

Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
with keys corresponding to symbols of the observations, and values being matrices
Expand Down Expand Up @@ -184,21 +225,21 @@ julia> model = demo(randn(3), randn());

julia> chain = sample(model, MH(), 10);

julia> pointwise_loglikelihoods(model, chain)
julia> pointwise_logdensities(model, chain)
OrderedDict{String,Array{Float64,2}} with 4 entries:
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]

julia> pointwise_loglikelihoods(model, chain, String)
julia> pointwise_logdensities(model, chain, String)
OrderedDict{String,Array{Float64,2}} with 4 entries:
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]

julia> pointwise_loglikelihoods(model, chain, VarName)
julia> pointwise_logdensities(model, chain, VarName)
OrderedDict{VarName,Array{Float64,2}} with 4 entries:
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
Expand All @@ -217,41 +258,46 @@ julia> @model function demo(x)

julia> m = demo([1.0, ]);

julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])])
julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])])
-1.4189385332046727

julia> m = demo([1.0; 1.0]);

julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
(-1.4189385332046727, -1.4189385332046727)
```

"""
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
function pointwise_logdensities(model::Model, chain,
context::AbstractContext=DefaultContext(), keytype::Type{T}=String) where {T}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# Get the data by executing the model once
vi = VarInfo(model)
context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}())
point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
for (sample_idx, chain_idx) in iters
# Update the values
setval!(vi, chain, sample_idx, chain_idx)

# Execute model
model(vi, context)
model(vi, point_context)
end

niters = size(chain, 1)
nchains = size(chain, 3)
loglikelihoods = OrderedDict(
logdensities = OrderedDict(
varname => reshape(logliks, niters, nchains) for
(varname, logliks) in context.loglikelihoods
(varname, logliks) in point_context.logdensities
)
return loglikelihoods
return logdensities
end

function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}())
model(varinfo, context)
return context.loglikelihoods
function pointwise_logdensities(model::Model,
varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext())
point_context = PointwiseLogdensityContext(
OrderedDict{VarName,Vector{Float64}}(), context)
model(varinfo, point_context)
return point_context.logdensities
end


46 changes: 46 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1042,4 +1042,50 @@ function test_context_interface(context)
end
end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
"""
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
mod::T
context::Ctx
end
function TestLogModifyingChildContext(
mod=1.2,
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(),
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()),
bgctw marked this conversation as resolved.
Show resolved Hide resolved
)
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(
mod, context
)
end
# Samplers call leafcontext(model.context) when evaluating log-densities
# Hence, in order to be used need to say that its a leaf-context
#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf()
bgctw marked this conversation as resolved.
Show resolved Hide resolved
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
return TestLogModifyingChildContext(context.mod, child)
end
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
#@info "TestLogModifyingChildContext tilde_assume!! called for $vn"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
return value, logp*context.mod, vi
end
function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi)
#@info "TestLogModifyingChildContext dot_tilde_assume!! called for $vn"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
return value, logp*context.mod, vi
end
function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
# @info "called tilde_observe TestLogModifyingChildContext for left=$left, right=$right"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
return logp*context.mod, vi
end
function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
return logp*context.mod, vi
end
bgctw marked this conversation as resolved.
Show resolved Hide resolved


end
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -26,10 +27,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Accessors = "0.1"
ADTypes = "0.2, 1"
AbstractMCMC = "5"
AbstractPPL = "0.8.2"
Accessors = "0.1"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
Expand Down
4 changes: 2 additions & 2 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using DynamicPPL:
NodeTrait,
IsLeaf,
IsParent,
PointwiseLikelihoodContext,
PointwiseLogdensityContext,
contextual_isassumption,
ConditionContext,
hasconditioned,
Expand Down Expand Up @@ -67,7 +67,7 @@ end
SamplingContext(),
MiniBatchContext(DefaultContext(), 0.0),
PrefixContext{:x}(DefaultContext()),
PointwiseLikelihoodContext(),
PointwiseLogdensityContext(),
ConditionContext((x=1.0,)),
ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))),
ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))),
Expand Down
Loading