Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/jaxtyping #746

Merged
merged 26 commits into from
Sep 4, 2024
Merged

Feat/jaxtyping #746

merged 26 commits into from
Sep 4, 2024

Conversation

db091756
Copy link
Member

@db091756 db091756 commented Aug 22, 2024

PR Type

  • Refactoring (no functional changes)

Description

I do not touch the typing of the tests in this PR.

Implement Jaxtyping throughout. There should be minimal to no use of ArrayLike or Array without shaping unless the shape is unknown.

There should also be very few Pyright complaints now, except those I couldn't fix with the code as is.

How Has This Been Tested?

Does this PR introduce a breaking change?

(Write your answer here.)

Screenshots

(Write your answer here.)

Checklist before requesting a review

  • I have made sure that my PR is not a duplicate.
  • My code follows the style guidelines of this project.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have performed a self-review of my code.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have added tests that prove my fix is effective or that my feature works.
  • New and existing unit tests pass locally with my changes.
  • Any dependent changes have been merged and published in downstream modules.

@db091756 db091756 linked an issue Aug 22, 2024 that may be closed by this pull request

This comment was marked as outdated.

@db091756 db091756 marked this pull request as ready for review August 23, 2024 14:12

This comment was marked as outdated.

@db091756
Copy link
Member Author

db091756 commented Aug 23, 2024

Performance review

Statistically significant changes

  • basic_stein:

    • OLD: compilation 3.193 units ± 0.179 units; execution 2.867 units ± 0.04594 units
    • NEW: compilation 3.289 units ± 0.2395 units; execution 6.56 units ± 0.03179 units
    • Significant increase in execution time (128.85%, p=6.059e-29)

Normalisation values for new data: Compilation: 1 unit = 530.66 ms Execution: 1 unit = 681.94 ms

Significant slowdown in the Stein Thinning method, unsure why as I don't believe I have changed anything functionally relevant here

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.193 units ± 0.179 units; execution 2.867 units ± 0.04594 units
    • NEW: compilation 3.335 units ± 0.1961 units; execution 6.451 units ± 0.04194 units
    • Significant increase in execution time (125.04%, p=1.213e-30)

Normalisation values for new data:
Compilation: 1 unit = 514.29 ms
Execution: 1 unit = 678.56 ms

@rg936672 rg936672 self-requested a review August 28, 2024 13:36
Copy link
Contributor

@rg936672 rg936672 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an impressive effort! There are a couple of recurring issues:

  • Nested unions (i.e. Union[A, B, Union[C, D]]) are redundant - they should be flattened out to e.g. Union[A, B, C, D]
  • Any plan Array type hint that cannot have its shape annotated should have a comment noting why the shape is not possible to statically determine

.cspell/custom_misc.txt Outdated Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
coreax/approximation.py Outdated Show resolved Hide resolved
coreax/approximation.py Outdated Show resolved Hide resolved
coreax/coreset.py Show resolved Hide resolved
coreax/score_matching.py Outdated Show resolved Hide resolved
coreax/score_matching.py Outdated Show resolved Hide resolved
coreax/solvers/coresubset.py Outdated Show resolved Hide resolved
coreax/util.py Outdated Show resolved Hide resolved
coreax/util.py Outdated Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
@rg936672
Copy link
Contributor

For the remaining pyright complaints (only those in the coreax module, not tests):

  • The pyright complaints in coreax.kernel can be fixed by replacing each from ... import XYZKernel with from ... import _XYZKernel; the issue seems to just be that the name is reused
  • The complaint in Coresubset.__init__ can be fixed as described by inheriting from Coreset[Data] rather than Coreset[_Data]
  • The complaint in networks.create_train_state can be fixed by replacing random_key with jax.random.key(random_key) (or an equivalently aliased import).

With these, the only remaining pyright complaints will be within the tests!

@db091756
Copy link
Member Author

db091756 commented Aug 29, 2024

For the remaining pyright complaints (only those in the coreax module, not tests):

  • The pyright complaints in coreax.kernel can be fixed by replacing each from ... import XYZKernel with from ... import _XYZKernel; the issue seems to just be that the name is reused
  • The complaint in Coresubset.__init__ can be fixed as described by inheriting from Coreset[Data] rather than Coreset[_Data]
  • The complaint in networks.create_train_state can be fixed by replacing random_key with jax.random.key(random_key) (or an equivalently aliased import).

With these, the only remaining pyright complaints will be within the tests!

Fixed the last one. Put some questions about the second one in the original comment. The first one I think I am misunderstanding as it seems like this would create import errors. Because the _XYZKernel doesn't exist? I assume you mean from ... import XYZKernel as _XYZKernel, I have implemented that change!

@db091756
Copy link
Member Author

db091756 commented Aug 29, 2024

Hopefully responded to all your comments! I also removed some of the overloads that I think are not necessary

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.285 units ± 0.1923 units; execution 2.931 units ± 0.01816 units
    • NEW: compilation 3.299 units ± 0.259 units; execution 6.4 units ± 0.048 units
    • Significant increase in execution time (118.36%, p=4.212e-22)

Normalisation values for new data:
Compilation: 1 unit = 522.08 ms
Execution: 1 unit = 684.06 ms

@db091756
Copy link
Member Author

Adding more overloads that I missed

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.285 units ± 0.1923 units; execution 2.931 units ± 0.01816 units
    • NEW: compilation 3.377 units ± 0.2713 units; execution 6.349 units ± 0.0442 units
    • Significant increase in execution time (116.63%, p=4.441e-23)

Normalisation values for new data:
Compilation: 1 unit = 521.73 ms
Execution: 1 unit = 682.68 ms

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.285 units ± 0.1923 units; execution 2.931 units ± 0.01816 units
    • NEW: compilation 3.321 units ± 0.2564 units; execution 6.264 units ± 0.05966 units
    • Significant increase in execution time (113.72%, p=1.303e-19)

Normalisation values for new data:
Compilation: 1 unit = 512.48 ms
Execution: 1 unit = 686.93 ms

Copy link
Contributor

@rg936672 rg936672 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good! A few issues remain:

  • A couple of unresolved comments from the first review
  • Docs build is failing with py:class reference target not found: 'n *d' and 'n *p'
  • Tests are now failing!
  • There are a couple of new Pyright complaints that have popped up (only 8 errors, though). Happy for these to be deferred to a ticket where we include a formal type checker (Include PyRight in CI #427 or Add a runtime type checker #755) if needed though

coreax/kernel.py Outdated Show resolved Hide resolved
@rg936672
Copy link
Contributor

Since other PRs (for example #758) aren't introducing any performance changes, the degradation in Stein performance is an issue created by this PR. This bears investigating!

@rg936672
Copy link
Contributor

Adding the lines

    ("py:class", "'n d'"),
    ("py:class", "'n p'"),

to the Sphinx nitpick_ignore seems to make the docs build pass.

Copy link
Contributor

@rg936672 rg936672 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good! I've found the source of the performance issues, and the docs build can be fixed by #746 (comment). There's still a couple of Pyright complaints, but I'm happy for these to be left for #427 if needed.

Once the docs build is passing, and the line causing the performance issues is either reverted or given an appropriate TODO and issue, I'm happy for this to be merged.

coreax/score_matching.py Outdated Show resolved Hide resolved
@db091756
Copy link
Member Author

Adding the lines

    ("py:class", "'n d'"),
    ("py:class", "'n p'"),

to the Sphinx nitpick_ignore seems to make the docs build pass.

Ah I was doing
("py:class", "n d"),
("py:class", "n p"),

thanks!

@db091756
Copy link
Member Author

Looking good! I've found the source of the performance issues, and the docs build can be fixed by #746 (comment). There's still a couple of Pyright complaints, but I'm happy for these to be left for #427 if needed.

Once the docs build is passing, and the line causing the performance issues is either reverted or given an appropriate TODO and issue, I'm happy for this to be merged.

I think we leave the remaining Pyright complaints for Beartype to help us with

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.138 units ± 0.2318 units; execution 2.917 units ± 0.03883 units
    • NEW: compilation 3.321 units ± 0.2983 units; execution 6.392 units ± 0.05105 units
    • Significant increase in execution time (119.12%, p=1.106e-28)

Normalisation values for new data:
Compilation: 1 unit = 528.56 ms
Execution: 1 unit = 682.89 ms

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.138 units ± 0.2318 units; execution 2.917 units ± 0.03883 units
    • NEW: compilation 3.214 units ± 0.2804 units; execution 6.386 units ± 0.05752 units
    • Significant increase in execution time (118.94%, p=1.077e-26)

Normalisation values for new data:
Compilation: 1 unit = 514.05 ms
Execution: 1 unit = 680.67 ms

Copy link
Contributor

Performance review

Statistically significant changes

  • basic_stein:
    • OLD: compilation 3.138 units ± 0.2318 units; execution 2.917 units ± 0.03883 units
    • NEW: compilation 3.237 units ± 0.2076 units; execution 6.3 units ± 0.01836 units
    • Significant increase in execution time (115.96%, p=4.99e-25)

Normalisation values for new data:
Compilation: 1 unit = 527.58 ms
Execution: 1 unit = 679.46 ms

Copy link
Contributor

github-actions bot commented Sep 3, 2024

Performance review

No significant changes to performance.

Copy link
Contributor

github-actions bot commented Sep 3, 2024

Performance review

No significant changes to performance.

Copy link
Contributor

@rg936672 rg936672 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some final changes to the type hint for _atleast_2d_consistent, and then I think we're done!

coreax/data.py Outdated Show resolved Hide resolved
Copy link
Contributor

github-actions bot commented Sep 4, 2024

Performance review

No significant changes to performance.

2 similar comments
Copy link
Contributor

github-actions bot commented Sep 4, 2024

Performance review

No significant changes to performance.

Copy link
Contributor

github-actions bot commented Sep 4, 2024

Performance review

No significant changes to performance.

@db091756 db091756 removed a link to an issue Sep 4, 2024
@db091756 db091756 linked an issue Sep 4, 2024 that may be closed by this pull request
1 task
Copy link
Contributor

@rg936672 rg936672 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Nicely done.

@rg936672 rg936672 merged commit cc4722e into main Sep 4, 2024
23 checks passed
@rg936672 rg936672 deleted the feat/jaxtyping branch September 4, 2024 11:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Annotate array dimensions using jaxtyping
2 participants