Pointpriors #663

Merged
merged 39 commits into from
Sep 30, 2024

Conversation

bgctw
Contributor

@bgctw bgctw commented Sep 17, 2024

Tackles #662: Querying the log-density of components of the prior.

The implementation does not decompose the log-density of dot-tilde expressions, because a possible solution (first commit, but removed in 3rd commit again) would need to decompose dot_assume, which is not under context control. However, I do need to pass computation to child-contexts, because I want to inspect log-density transformation by child-contexts. Therefore, I called it varwise_logpriors rather than pointwise_logpriors.

In addition, I decided on a different handling of a Chains object of samples compared to pointwise_likelihoods, because I did not fully comprehend its different push!! methods and the different initializers for the collecting OrderedDict, nor which of them applies under which conditions. Instead, I tried to separate the concerns of querying densities for a single sample and applying this to a Chains object. I hope that the mutation of a pre-allocated array is ok here.

Member

@torfjelde torfjelde left a comment


Hey @bgctw !

Can you clarify why you want the .~ statements to be treated as a single log-prob in your case? You mention that your motivation is tempering; it's a bit unclear to me why varwise_logpriors are needed for this. And why is the Chain needed in this case? When I think of tempering in our context, I'm imagining altering the likelihood / prior weightings during sampling, not as a post-inference step.

Maybe even writing a short bit of pseudo-code outlining what you want to do with this could help!

From your initial motivation in #662, I feel like we can probably find alternative approaches that might be a bit simpler:)

@bgctw
Contributor Author

bgctw commented Sep 17, 2024

My goal is to modify the log-density during sampling. I imagine putting something similar to TestLogModifyingChildContext in src/test_utils.jl between the SamplingContext and the DefaultContext. For example, I want to relax the parameter priors of an ODE model during burn-in or initial optimization of a Turing model, but keep the original priors on additive effects that modify parameters for simulated replicates around a population mean. However, before tackling this, I want to be able to query/see the corresponding log-densities that are used during the sampling or optimization.
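To make the idea of a log-density-modifying context sitting between two other contexts concrete, here is a minimal Base-only sketch. All names in it (`ToyContext`, `ToyDefault`, `ToyScale`, `prior_logp`) are hypothetical and not part of DynamicPPL, and scaling by a constant factor is only one assumed way of "relaxing" a prior:

```julia
# Toy illustration only: these types are hypothetical, not DynamicPPL contexts.
abstract type ToyContext end

# Leaf context: passes the prior log-density through unchanged.
struct ToyDefault <: ToyContext end

# Wrapper context: scales whatever its child context returns.
struct ToyScale{C<:ToyContext} <: ToyContext
    factor::Float64
    child::C
end

prior_logp(::ToyDefault, logp) = logp
prior_logp(ctx::ToyScale, logp) = ctx.factor * prior_logp(ctx.child, logp)

# A prior term of -4.0 as seen through a context that halves prior log-densities.
relaxed = prior_logp(ToyScale(0.5, ToyDefault()), -4.0)
```

Because the wrapper delegates to its child before modifying the result, any query tool that wants to report what the sampler actually sees has to run the computation through the same chain of contexts.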

Hence, I want to query the log-densities of the prior components as seen by a sampler that generated the samples in an AbstractMCMC.AbstractChains object.

The single number provided by logprior(m, vi) for a single sample is too coarse, because I want to experiment with components.

The pointwise resolution, i.e. also resolving the individual components of a dot_tilde_assume, such as (s[1], s[2], ..., s[n]) of s .~ Normal(...), would be nice, but it is more complex to implement together with the requirement that those densities can be modified by a child context. Reporting only their accumulated log-density is an acceptable tradeoff. If this is prohibitive, the user could reformulate the .~ as an explicit loop, because then those components are resolved with the currently suggested implementation.
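The difference between the two resolutions can be sketched without any packages. The snippet below is only an illustration under stated assumptions: `normal_logpdf` is a hand-written helper, the values in `s` are made up, and the dictionary keys merely mimic how per-element names might look.

```julia
# Log-density of Normal(mu, sigma) at x, written out to keep the sketch Base-only.
normal_logpdf(x, mu, sigma) = -0.5 * log(2π) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

s = [0.0, 1.0, -2.0]

# Variable-wise resolution: a single accumulated log-density for the whole of `s`.
summed = sum(normal_logpdf.(s, 0.0, 1.0))

# Pointwise resolution, as an explicit loop over the elements would expose it:
# one entry per s[i].
per_element = Dict("s[$i]" => normal_logpdf(s[i], 0.0, 1.0) for i in eachindex(s))
```

The per-element entries sum to the accumulated value, so the coarser report loses resolution but not total log-density.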

@coveralls

coveralls commented Sep 17, 2024

Pull Request Test Coverage Report for Build 11091092915

Details

  • 106 of 128 (82.81%) changed or added relevant lines in 2 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.6%) to 78.105%

Changes Missing Coverage        Covered Lines  Changed/Added Lines  %
src/test_utils.jl               13             19                   68.42%
src/pointwise_logdensities.jl   93             109                  85.32%

Totals
Change from base Build 11051106360: 0.6%
Covered Lines: 2786
Relevant Lines: 3567

💛 - Coveralls

use loop for prior in example

Unfortunately cannot make it a jldoctest, because relies on Turing for sampling
@torfjelde
Member

Hence, I want to query the log-densities of the prior components as seen by a sampler that generated the samples in an AbstractMCMC.AbstractChains object.

Ah, gotcha; this was the aspect I was missing 👍

The pointwise resolution, i.e. also resolving the individual components of a dot_tilde_assume, such as (s[1], s[2], ..., s[n]) of s .~ Normal(...), would be nice, but it is more complex to implement together with the requirement that those densities can be modified by a child context. Reporting only their accumulated log-density is an acceptable tradeoff. If this is prohibitive, the user could reformulate the .~ as an explicit loop, because then those components are resolved with the currently suggested implementation.

Makes sense 👍

Taking this into account, I'm wondering if maybe it would be better to just generalize the existing PointwiseLikelihoodContext that we have here

struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext
    loglikelihoods::A
    context::Ctx
end

We can just add a "switch" to it (or maybe just inspect the leaf context) to determine what logprobs we should keep around. AFAIK this should just require implementing the following:

  1. tilde_assume and dot_tilde_assume
  2. A quick check in (1) to determine whether we should include a variable or not.

Then we can just add alternatives to the following user-facing method

function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
    # Get the data by executing the model once
    vi = VarInfo(model)
    context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}())
    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    for (sample_idx, chain_idx) in iters
        # Update the values
        setval!(vi, chain, sample_idx, chain_idx)
        # Execute model
        model(vi, context)
    end
    niters = size(chain, 1)
    nchains = size(chain, 3)
    loglikelihoods = OrderedDict(
        varname => reshape(logliks, niters, nchains) for
        (varname, logliks) in context.loglikelihoods
    )
    return loglikelihoods
end
function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
    context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}())
    model(varinfo, context)
    return context.loglikelihoods
end

e.g. pointwise_prior_logprobs or something.

So all in all, basically what you've already done, but just as part of the PointwiseLikelihoodContext (which we should then subsequently rename of course).

Thoughts?
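As a side note on the chain method quoted above: the final reshape works because Iterators.product varies its first index fastest, which matches Julia's column-major array layout. A minimal Base-only sketch of just that step follows; `fake_logp` is a hypothetical stand-in for executing the model on one sample and recording one log-density.

```julia
# Hypothetical stand-in for "run the model on sample (i, c) and record a log-density".
fake_logp(sample_idx, chain_idx) = 10.0 * chain_idx + sample_idx

niters, nchains = 3, 2
flat = Float64[]
# Same iteration pattern as in the quoted method: samples vary fastest, then chains.
for (sample_idx, chain_idx) in Iterators.product(1:niters, 1:nchains)
    push!(flat, fake_logp(sample_idx, chain_idx))
end

# Column-major reshape puts sample i of chain c at position [i, c].
mat = reshape(flat, niters, nchains)
```

Any prior-oriented alternative that collects values in the same loop order could presumably reuse this reshape unchanged.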

@bgctw
Contributor Author

bgctw commented Sep 18, 2024

Trying to unify those two is a good idea. In fact, I originally started exploring/modifying based on PointwiseLikelihoodContext.

However, I did not get far with this. PointwiseLikelihoodContext resolves the dot_tilde_observe by intercepting before the acclogp!!, but can still forward the density computation to the child context. I did not manage to do that with the priors. Hence, I do not know how to implement your unifying suggestion.

@bgctw
Contributor Author

bgctw commented Sep 19, 2024

I will attempt the implementation that you suggested, assuming that components of the prior are not resolved to the same detail as the components of the likelihood.

@bgctw
Contributor Author

bgctw commented Sep 19, 2024

I pushed a new commit that integrates pointwise_loglikelihoods and varwise_logpriors to the new function pointwise_logdensities.

The hardest part was to create a single VarName from the AbstractVector{VarName} for the case where only the summed logdensity for several prior components in dot_tilde_assume is to be recorded. varwise_logpriors simply used a Symbol, but the generalized pointwise_loglikelihoods requires a single VarName. The implementation at src/pointwise_logdensities.jl around line 153 has to assume several details of the Optics in the given VarNames.

Another issue is that pointwise_loglikelihoods now provides information on all variables, although the log-density of the priors is zero. Hence, one cannot check for an empty result (around line 29) to catch the case of literal observations. How can I ask the model or VarInfo which variables are priors and which are observations?

I could not yet recreate the julia-repl block in the documentation of the function, because the current Turing, which is required for sampling in the docstring, is not compatible with the current DynamicPPL.

@torfjelde
Member

Lovely @bgctw ! I'll have a proper look at it a bit later today:)

@bgctw
Contributor Author

bgctw commented Sep 19, 2024

To let the user select the relevant information and to save processing time, it could be helpful to have two keyword arguments with defaults: report_logpriors=true and report_loglikelihoods=true. If the corresponding flag is false, the log-densities would not be calculated (not passed to the child context) and would not appear in the results. report_logpriors could be set to false in the forwarding of pointwise_loglikelihoods, which would also allow checking for empty results in tests again.

Would these be reasonable?
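A rough Base-only sketch of how two such flags could gate both the computation and the recording is given below. The flag names come from the proposal above; everything else (`ToyRecorder`, `record_prior!`, `record_likelihood!`) is hypothetical and not DynamicPPL API.

```julia
# Toy illustration only: this struct and both functions are hypothetical.
struct ToyRecorder
    logps::Dict{String,Vector{Float64}}
    report_logpriors::Bool
    report_loglikelihoods::Bool
end

function ToyRecorder(; report_logpriors=true, report_loglikelihoods=true)
    return ToyRecorder(Dict{String,Vector{Float64}}(), report_logpriors, report_loglikelihoods)
end

# Stand-in for a prior term; `compute_logp` is only called when the flag is set.
function record_prior!(rec::ToyRecorder, name, compute_logp)
    rec.report_logpriors || return nothing
    logp = compute_logp()
    push!(get!(rec.logps, name, Float64[]), logp)
    return logp
end

# Stand-in for a likelihood term, gated by the other flag.
function record_likelihood!(rec::ToyRecorder, name, compute_logp)
    rec.report_loglikelihoods || return nothing
    logp = compute_logp()
    push!(get!(rec.logps, name, Float64[]), logp)
    return logp
end

rec = ToyRecorder(; report_logpriors=false)
record_prior!(rec, "s", () -> -1.5)          # skipped: not computed, not recorded
record_likelihood!(rec, "y[1]", () -> -0.3)  # computed and recorded
```

With report_logpriors=false the result holds only likelihood entries, so a test on a model without observations would again see an empty result.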

by forwarding dot_tilde_assume to tilde_assume
@bgctw
Contributor Author

bgctw commented Sep 21, 2024

I found a way to record single log-density prior components in dot_tilde_assume: I forward each variable to tilde_assume but currently take the value and VarInfo from dot_tilde_assume applied with the child context. This assumes that VarInfo is not mutated in tilde_assume.
There is still the case of tilde assignment to a single multivariate distribution, where there is only a single log-density for a combination of VarNames, e.g. s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) in TestUtils.demo_dot_assume_matrix_dot_observe_matrix. Hence, I cannot avoid combining indices of VarNames.

The forwarding of dot_tilde_assume to multiple tilde_assume works for the PointwiseLogdensityContext case. Is there potential to apply it also at other places to simplify DynamicPPL or is the separate dispatch mechanism important?

@bgctw
Contributor Author

bgctw commented Sep 22, 2024

Forwarding to tilde_assume now also works for the case of tilde assignment to a single multivariate distribution, e.g. s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) in TestUtils.demo_dot_assume_matrix_dot_observe_matrix. There is no longer any need to combine indices of VarNames.

Member

@torfjelde torfjelde left a comment


Sorry, I was working on some changes in your branch and wanted to make a PR to yours, but doesn't seem like that works due to you being on a fork o.O (or maybe I'm just being stupid).

So instead I made a new PR over at #669 . You can see the diff from yours to mine that I added here: https://github.com/TuringLang/DynamicPPL.jl/pull/669/files/5842656154a5b2f9a0377c45a4d4438933971a11..8bd2085098208fc58d1e33bbe48ec56e7efcd691

EDIT: Did this because it was a bit easier to demonstrate what I had in mind rather than explaining it through a bunch of comments

@bgctw
Contributor Author

bgctw commented Sep 23, 2024

So instead I made a new PR over at #669 . You can see the diff from yours to mine that I added here: https://github.com/TuringLang/DynamicPPL.jl/pull/669/files/5842656154a5b2f9a0377c45a4d4438933971a11..8bd2085098208fc58d1e33bbe48ec56e7efcd691

I see

  • the different undeprecated subtypes of pointwise_logdensities
  • the _istcontext to suppress entire prior or likelihood recording
  • I do not understand some of the code (some dubbed Hack)

Your PR is based on an older version of this PR. What is the way forward now? Should I try to merge your changes to this PR? Or should I try to implement my subsequent changes to your PR?
I am not as experienced with forks and contributing to pull-requests. How can I make my fork writeable/pushable to you?

and avoid recording likelihoods when invoked with leaf-Likelihood context
bgctw first forwarded dot_tilde_assume to get a correct vi
and then recomputed it for recording component prior densities.

Replaced this by the Hack of torfjelde that completely drops vi and recombines the value, so that assume is called only once for each varName,
pointwise_prior_logdensities
int api.md docu
@bgctw
Contributor Author

bgctw commented Sep 24, 2024

I transferred the developments in #669 to this PR. The solution with dropping the updated VarInfos and only relying on the non-dot version is more efficient than my version of recomputing the VarInfo from the original dot-version of tilde_assume. Although, the "flatten and recombine" hack (line 230) for the Multivariate Distribution is hard to comprehend, and one needs to remember that when modifying dot_tilde_assume, one now needs to also consistently adapt _point_tilde_assume.

@torfjelde
Member

Then, samplers and other code would only deal with the non-dot versions. Maybe this makes a few performance and other tweaks impossible, such as samplers hooking into the dot-dispatch. But this would be more gentle compared to deprecate the support for .~ entirely.

Regarding this, I think it's worth preserving the .~ in the same sense as .= works in Julia:) It does both have performance implications and is a nice semantic for users. But yeah, agree that it's a bit annoying.

@bgctw
Contributor Author

bgctw commented Sep 26, 2024

The suggestions from code review introduced some errors in the tests, which I tried to fix. However, I did not succeed for the "pointwise_logdensities chain" testset. Could you please have another look at whether this is a problem of the test setup or of the tested functionality? Your test is stricter, because it compares to logjoint_true(model, val...) rather than pointwise_logdensities(model, VarInfo).

@torfjelde
Member

Ah yeah, it's failing because DynamicPPL.TestUtils.varnames(model) only returns the variables that are considered random, i.e. excluding the observations. Lemme have a go at fixing this 👍

@torfjelde
Member

torfjelde commented Sep 26, 2024

Fixed the test @bgctw :)


codecov bot commented Sep 26, 2024

Codecov Report

Attention: Patch coverage is 82.81250% with 22 lines in your changes missing coverage. Please review.

Project coverage is 77.66%. Comparing base (067ac4c) to head (cff0941).
Report is 1 commits behind head on master.

Files with missing lines        Patch %   Lines
src/pointwise_logdensities.jl   85.32%    16 Missing ⚠️
src/test_utils.jl               68.42%    6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #663      +/-   ##
==========================================
+ Coverage   75.93%   77.66%   +1.73%     
==========================================
  Files          29       29              
  Lines        3519     3587      +68     
==========================================
+ Hits         2672     2786     +114     
+ Misses        847      801      -46     
Flag Coverage Δ
77.66% <82.81%> (+1.73%) ⬆️


@bgctw
Contributor Author

bgctw commented Sep 27, 2024

Thanks @torfjelde for patiently guiding me through this process.

@bgctw
Contributor Author

bgctw commented Sep 27, 2024

Another maybe:

I find it more convenient to work with the results of the pointwise functions applied to AbstractChains as an AbstractChains again, rather than the OrderedDict{String, Matrix}.
What is a good place to support this conversion? A function in the MCMCChains extension such as:

using MCMCChains: Chains

# Convert the dictionary of (iterations × chains) matrices into a Chains object.
function as_chains(lds_pointwise)
    return Chains(stack(values(lds_pointwise); dims=2), collect(keys(lds_pointwise)))
end
chn = as_chains(logjoints_pointwise); # from @testset "pointwise_logdensities chain"
names(chn)
get(chn, :x)[1] == logjoints_pointwise["x"]
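The array layout that the stack call produces can be checked without MCMCChains. The sketch below uses made-up numbers and a plain vector of pairs in place of the OrderedDict, and it assumes Julia 1.9 or later, where `stack` is available in Base:

```julia
# Three variables, each with an (iterations × chains) matrix of made-up log-densities.
niters, nchains = 4, 2
lds = [
    "x" => fill(-1.0, niters, nchains),
    "y" => fill(-2.0, niters, nchains),
    "z" => fill(-3.0, niters, nchains),
]

# Stacking along a new second dimension gives iterations × variables × chains.
arr = stack(last.(lds); dims=2)
var_names = first.(lds)
```

Here `size(arr)` is `(4, 3, 2)`, i.e. the variables end up in the middle dimension, which is the iterations × parameters × chains layout that the Chains constructor takes.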

One could even think of letting the pointwise_logdensities(..., ::AbstractChains) routinely return a Chains object.
This could be achieved by

  • renaming the pointwise_logdensities(..., ::AbstractChains) to pointwise_logdensities_dict,
  • implementing it in the DynamicPPLMCMCChains extension by converting the result of pointwise_logdensities_dict.

Since this would break the current interface of pointwise_logdensities, I am making this suggestion here instead of in a separate issue, before pointwise_logdensities (and its siblings) go to the master branch.

@torfjelde
Member

I just pushed a final change to the docstring of pointwise_logdensities (which was out of date) + made it a doctest. Once tests pass now, I'll add it to the merge queue:)

Thanks @torfjelde for patiently guiding me through this process.

Of course! Glad to hear you found it useful:)

I find it more convenient to work with the results of the pointwise functions applied to AbstractChains as an AbstractChains again, rather than the OrderedDict{String, Matrix}.

Hmm, I'm a bit uncertain about this. I do see your reasoning that it might be beneficial, but I think, at least at the moment, I'm reluctant to make this part of DynamicPPL 😕 Generally, we adopt features in DPPL once we feel like there's sufficient need for it; atm, I think most people using pointwise_logdensities (me being amongst them), would rather work with an OrderedDict instead of Chains 😕

But how about you convert that comment into an issue so a) we can keep track of the desired feature and see if there are other people who share the interest in this, and b) so that the current impl you are using can also be discovered more easily by others?:)

@bgctw
Contributor Author

bgctw commented Sep 27, 2024

But how about you convert that comment into an issue so a) we can keep track of the desired feature and see if there are other people who share the interest in this, and b) so that the current impl you are using can also be discovered more easily by others?:)

I will do that after it's available on master.

@torfjelde
Member

Added it to the merge queue; thank you @bgctw !

@torfjelde torfjelde added this pull request to the merge queue Sep 30, 2024
Merged via the queue into TuringLang:master with commit 8c3aa44 Sep 30, 2024
12 of 13 checks passed