
Feat/jaxtyping #746

Merged: 26 commits (feat/jaxtyping → main), Sep 4, 2024

Changes from 23 commits

Commits
e23b174  feat: continue adding jaxtyping (db091756, Aug 22, 2024)
a4c4b99  remove files pulled from different PR (db091756, Aug 22, 2024)
f9871da  feat: finished typing pass (db091756, Aug 23, 2024)
3dc5d0b  Merge branch 'main' into feat/jaxtyping (db091756, Aug 23, 2024)
a07db43  fix: add back lost ruff disable (db091756, Aug 23, 2024)
2c5eb04  feat: another jaxtyping pass (db091756, Aug 23, 2024)
8b91eb9  fix: doc build errors (db091756, Aug 23, 2024)
aa188a6  feat: added back jit (db091756, Aug 23, 2024)
abbd5a2  responding to reviewer comments (db091756, Aug 29, 2024)
d574777  feat: added further jaxtyping (db091756, Aug 29, 2024)
16a4bc4  fix: replaced n with m in pairwise (db091756, Aug 29, 2024)
e2a1cf1  responding to reviewer comments (db091756, Aug 29, 2024)
751406e  responding to reviewer comments (db091756, Aug 29, 2024)
82271f9  feat: added overloads the compute mean method (db091756, Aug 29, 2024)
b897254  Merge branch 'main' into feat/jaxtyping (db091756, Aug 29, 2024)
b94f0c4  feat: added ruff rules (db091756, Aug 30, 2024)
419a639  feat: added ... to the exclude_also (db091756, Aug 30, 2024)
46bb0ea  responding to reviewer comments (db091756, Aug 30, 2024)
64e7215  fix: fix doc build error (db091756, Aug 30, 2024)
4041e0f  fix: we do not need to do the as_array conversion when the function i… (db091756, Aug 30, 2024)
713ac4c  fix: fixed missed typing (db091756, Aug 30, 2024)
dce12f7  fix: reverted dimensional control in score_matching.py, improved over… (db091756, Sep 3, 2024)
46f2c24  docs: changed 'd' to 'n' in overloads (db091756, Sep 3, 2024)
860ad4b  Merge branch 'main' into feat/jaxtyping (db091756, Sep 4, 2024)
bf51a87  docs: improved overloads (db091756, Sep 4, 2024)
c302936  tests: add back old tests (db091756, Sep 4, 2024)
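
Several of the commits above reference overloads added to a compute-mean method. As background, the general pattern (a hypothetical sketch, not coreax's actual signatures) uses typing.overload so that the declared return shape tracks the input shape:

    from typing import Union, overload

    import jax.numpy as jnp
    from jax import Array
    from jaxtyping import Shaped

    @overload
    def compute_mean(x: Shaped[Array, " n d"]) -> Shaped[Array, " d"]: ...

    @overload
    def compute_mean(x: Shaped[Array, " d"]) -> Shaped[Array, ""]: ...

    def compute_mean(x: Union[Shaped[Array, " n d"], Shaped[Array, " d"]]) -> Array:
        # Reduce over the leading axis; the overloads document the shape contract.
        return jnp.mean(x, axis=0)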
2 changes: 1 addition & 1 deletion .cspell/custom_misc.txt
@@ -25,7 +25,7 @@ kernelised
 kernelized
 KSD
 linewidth
-Matérn
+Matern
 ml.p3.8xlarge
 ndmin
 parsable
1 change: 1 addition & 0 deletions .cspell/people.txt
@@ -14,6 +14,7 @@ Jiaxin
 Jitkrittum
 Kanagawa
 Martinsson
+Matérn
 Motonobu
 Nystr
 Nystrom
124 changes: 88 additions & 36 deletions coreax/approximation.py
@@ -36,10 +36,10 @@
 import jax.numpy as jnp
 import jax.random as jr
 from jax import Array
-from jax.typing import ArrayLike
+from jaxtyping import Shaped
 from typing_extensions import TYPE_CHECKING, Literal, override

-from coreax.data import Data
+from coreax.data import Data, _atleast_2d_consistent
 from coreax.kernels import UniCompositeKernel
 from coreax.util import KeyArrayLike

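For readers new to jaxtyping: annotations like Shaped[Array, " n d"] encode an array's shape in the type, with letters naming symbolic dimensions that must agree wherever they are reused. A minimal illustrative sketch (the function here is invented for illustration, not part of coreax):

    import jax.numpy as jnp
    from jax import Array
    from jaxtyping import Shaped

    def pairwise_dot(
        x: Shaped[Array, " n d"], y: Shaped[Array, " m d"]
    ) -> Shaped[Array, " n m"]:
        # "d" must match between x and y; the result pairs every row of x with y.
        return x @ y.T

    out = pairwise_dot(jnp.ones((3, 2)), jnp.ones((5, 2)))  # shape (3, 5)
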
@@ -77,26 +77,28 @@ def _random_indices(

 def _random_least_squares(
     key: KeyArrayLike,
-    data: Array,
-    features: Array,
+    data: Shaped[Array, " n p"],
+    features: Shaped[Array, " n n"],
     num_indices: int,
-    target_map: Callable[[Array], Array] = lambda x: x,
-) -> Array:
+    target_map: Callable[[Shaped[Array, " n p"]], Shaped[Array, " n p"]] = lambda x: x,
+) -> Shaped[Array, " n p"]:
     r"""
     Solve the least-squares problem on a random subset of the system.

-    A linear system :math:`Ax = b`, solved via least-squares as :math:`x = A^+ b`, can
-    be approximated by random least-square as `x \approx \hat{x} = \hat{A}^+ \hat{b}`,
-    where :math:`\hat{A} = A_i\ \text{and}\ \hat{b} = b_i\, \forall i \in I]`. `I` is a
-    random subset of indices for the original system of equations.
+    A linear system :math:`AX = B`, solved via least-squares as :math:`X = A^+ B`, can
+    be approximated by random least-squares as
+    :math:`X \approx \hat{X} = \hat{A}^+ \hat{B}`, where
+    :math:`\hat{A} = A_{i\cdot}\ \text{and}\ \hat{B} = B_{i\cdot}\, \forall i \in I`.
+    :math:`I` is a random subset of indices for the original system of equations.

     :param key: RNG key for seeding the random selection
-    :param data: The data :math:`z`; yields :math:`b` when pushed through the target map
-    :param features: The feature matrix :math:`A`
+    :param data: The data :math:`Z \in \mathbb{R}^{n \times p}`; yields
+        :math:`B \in \mathbb{R}^{n \times p}` when pushed through the target map
+    :param features: The feature matrix :math:`A \in \mathbb{R}^{n \times n}`
     :param num_indices: The size of the random subset of indices :math:`I`
     :param target_map: The target map :math:`\phi` which defines :math:`b := \phi(z)`,
         where :math:`z` is the input ``data``
-    :return: The push-forward of the approximate solution :math:`A\hat{x}`
+    :return: The push-forward of the approximate solution :math:`A\hat{X}`
     """
     num_data_points = len(data)
     train_idx = _random_indices(key, num_data_points, num_indices, mode="train")
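
To make the approximation concrete, here is a self-contained sketch of the same idea in plain JAX (variable names are invented for the example; this is not coreax code):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    n, p, m = 100, 3, 20                  # m is the size of the random subset I
    A = jr.normal(key, (n, n))            # feature matrix A
    B = jr.normal(jr.PRNGKey(1), (n, p))  # targets B

    idx = jr.choice(key, n, shape=(m,), replace=False)  # random subset I
    A_hat, B_hat = A[idx], B[idx]                       # row subsets of A and B
    X_hat = jnp.linalg.lstsq(A_hat, B_hat)[0]           # \hat{X} = \hat{A}^+ \hat{B}
    push_forward = A @ X_hat                            # approximates A X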
@@ -120,19 +122,19 @@ class ApproximateKernel(UniCompositeKernel):
     """

     @override
-    def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
+    def compute_elementwise(self, x, y):
         return self.base_kernel.compute_elementwise(x, y)

     @override
-    def grad_x_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
+    def grad_x_elementwise(self, x, y):
         return self.base_kernel.grad_x_elementwise(x, y)

     @override
-    def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
+    def grad_y_elementwise(self, x, y):
         return self.base_kernel.grad_y_elementwise(x, y)

     @override
-    def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
+    def divergence_x_grad_y_elementwise(self, x, y):
         return self.base_kernel.divergence_x_grad_y_elementwise(x, y)

@@ -173,7 +175,18 @@ class MonteCarloApproximateKernel(RandomRegressionKernel):
     :param num_train_points: Number of training points used to fit kernel regression
     """

-    def gramian_row_mean(self, x: Union[ArrayLike, Data], **kwargs) -> Array:
+    def gramian_row_mean(
+        self,
+        x: Union[
+            Shaped[Array, " n d"],
+            Shaped[Array, " d"],
+            Shaped[Array, ""],
+            float,
+            int,
+            Data,
+        ],
+        **kwargs,
+    ) -> Shaped[Array, " n"]:
         r"""
         Approximate the Gramian row-mean by Monte-Carlo sampling.

@@ -184,17 +197,22 @@ def gramian_row_mean(self, x: Union[ArrayLike, Data], **kwargs) -> Array:
         :return: Approximation of the base kernel's Gramian row-mean
         """
         del kwargs
-        data = jnp.atleast_2d(jnp.asarray(x))
-        num_data_points = len(data)
+        # This method does not support weighted computation of the mean, therefore
+        # we need to handle the case where `x` is passed as a `Data` instance
+        if isinstance(x, Data):
+            x = x.data
+        x = _atleast_2d_consistent(x)
+
+        num_data_points = len(x)
         key = self.random_key
         features_idx = _random_indices(key, num_data_points, self.num_kernel_points - 1)
-        features = self.base_kernel.compute(data, data[features_idx])
+        features = self.base_kernel.compute(x, x[features_idx])
         return _random_least_squares(
             key,
-            data,
+            x,
             features,
             self.num_train_points,
-            partial(self.base_kernel.compute_mean, data, axis=0),
+            partial(self.base_kernel.compute_mean, x, axis=0),
         )

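The quantity being approximated in gramian_row_mean is the row-mean of the base kernel's Gram matrix, i.e. (1/n) * sum_j k(x_i, x_j) for each i. A brute-force reference implementation, useful for checking the regression-based approximations on small inputs (a sketch with a made-up RBF kernel; it materialises the full O(n^2) Gram matrix, which is exactly what the approximations avoid):

    import jax.numpy as jnp

    def gramian_row_mean_exact(x, kernel):
        # Full (n, n) Gram matrix, then the mean of each row.
        gram = kernel(x[:, None, :], x[None, :, :])
        return gram.mean(axis=1)

    rbf = lambda a, b: jnp.exp(-jnp.sum((a - b) ** 2, axis=-1))
    x = jnp.arange(12.0).reshape(4, 3)
    exact = gramian_row_mean_exact(x, rbf)  # shape (4,)
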
@@ -212,7 +230,18 @@ class ANNchorApproximateKernel(RandomRegressionKernel):
     :param num_train_points: Number of training points used to fit kernel regression
     """

-    def gramian_row_mean(self, x: Union[ArrayLike, Data], **kwargs) -> Array:
+    def gramian_row_mean(
+        self,
+        x: Union[
+            Shaped[Array, " n d"],
+            Shaped[Array, " d"],
+            Shaped[Array, ""],
+            float,
+            int,
+            Data,
+        ],
+        **kwargs,
+    ) -> Shaped[Array, " n"]:
         r"""
         Approximate the Gramian row-mean by random regression on ANNchor points.

@@ -224,12 +253,19 @@ def gramian_row_mean(self, x: Union[ArrayLike, Data], **kwargs) -> Array:
         :return: Approximation of the base kernel's Gramian row-mean
         """
         del kwargs
-        data = jnp.atleast_2d(jnp.asarray(x))
-        num_data_points = len(data)
+        # This method does not support weighted computation of the mean, therefore
+        # we need to handle the case where `x` is passed as a `Data` instance
+        if isinstance(x, Data):
+            x = x.data
+        x = _atleast_2d_consistent(x)
+
+        num_data_points = len(x)
         features = jnp.zeros((num_data_points, self.num_kernel_points))
-        features = features.at[:, 0].set(self.base_kernel.compute(data, data[0])[:, 0])
+        features = features.at[:, 0].set(self.base_kernel.compute(x, x[0])[:, 0])

-        def _annchor_body(idx: int, _features: Array) -> Array:
+        def _annchor_body(
+            idx: int, _features: Shaped[Array, " n num_kernel_points"]
+        ) -> Shaped[Array, " n num_kernel_points"]:
             r"""
             Execute main loop of the ANNchor construction.

@@ -239,17 +275,17 @@ def _annchor_body(idx: int, _features: Array) -> Array:
             """
             max_entry = _features.max(axis=1).argmin()
             _features = _features.at[:, idx].set(
-                self.base_kernel.compute(data, data[max_entry])[:, 0]
+                self.base_kernel.compute(x, x[max_entry])[:, 0]
             )
             return _features

         features = jax.lax.fori_loop(1, self.num_kernel_points, _annchor_body, features)
         return _random_least_squares(
             self.random_key,
-            data,
+            x,
             features,
             self.num_train_points,
-            partial(self.base_kernel.compute_mean, data, axis=0),
+            partial(self.base_kernel.compute_mean, x, axis=0),
         )

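Since _annchor_body runs under jax.lax.fori_loop, it may help to recall that fori_loop is semantically equivalent to a plain Python loop over a carried value (a reference sketch of the semantics, matching the JAX documentation):

    def fori_loop_reference(lower, upper, body_fun, init_val):
        # Pure-Python equivalent of jax.lax.fori_loop(lower, upper, body_fun, init_val)
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val

Here each iteration fills column idx of the feature matrix with kernel evaluations against the point that is currently worst-covered, i.e. the point with the smallest maximum similarity to the anchors chosen so far.
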
@@ -267,7 +303,18 @@ class NystromApproximateKernel(RandomRegressionKernel):
     :param num_train_points: Number of training points used to fit kernel regression
     """

-    def gramian_row_mean(self, x: Union[ArrayLike, Data], **kwargs) -> Array:
+    def gramian_row_mean(
+        self,
+        x: Union[
+            Shaped[Array, " n d"],
+            Shaped[Array, " d"],
+            Shaped[Array, ""],
+            float,
+            int,
+            Data,
+        ],
+        **kwargs,
+    ) -> Shaped[Array, " n"]:
         r"""
         Approximate the Gramian row-mean by Nystrom approximation.

@@ -280,15 +327,20 @@ def gramian_row_mean(self, x: Union[ArrayLike, Data], **kwargs) -> Array:
         :return: Approximation of the base kernel's Gramian row-mean
         """
         del kwargs
-        data = jnp.atleast_2d(jnp.asarray(x))
-        num_data_points = len(data)
+        # This method does not support weighted computation of the mean, therefore
+        # we need to handle the case where `x` is passed as a `Data` instance
+        if isinstance(x, Data):
+            x = x.data
+        x = _atleast_2d_consistent(x)
+
+        num_data_points = len(x)
         feature_idx = _random_indices(
             self.random_key, num_data_points, self.num_kernel_points
         )
-        features = self.base_kernel.compute(data, data[feature_idx])
+        features = self.base_kernel.compute(x, x[feature_idx])
         return _random_least_squares(
             self.random_key,  # intentional key reuse to ensure train_idx = feature_idx
-            data,
+            x,
             features,
             self.num_train_points,
             self.base_kernel.gramian_row_mean,
53 changes: 31 additions & 22 deletions coreax/coreset.py
@@ -21,7 +21,7 @@
 from jaxtyping import Array, Shaped
 from typing_extensions import Self

-from coreax.data import Data, SupervisedData, as_data
+from coreax.data import Data, as_data, as_supervised_data
 from coreax.metrics import Metric
 from coreax.weights import WeightsOptimiser

@@ -73,9 +73,25 @@ class Coreset(eqx.Module, Generic[_Data]):
     :param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
     """

-    nodes: Data = eqx.field(converter=as_data)
+    nodes: _Data
     pre_coreset_data: _Data

+    def __init__(self, nodes: _Data, pre_coreset_data: _Data):
+        """Handle type conversion of ``nodes`` and ``pre_coreset_data``."""
+        if isinstance(nodes, Array):
+            self.nodes = as_data(nodes)
+        elif isinstance(nodes, tuple):
+            self.nodes = as_supervised_data(nodes)
+        else:
+            self.nodes = nodes
+
+        if isinstance(pre_coreset_data, Array):
+            self.pre_coreset_data = as_data(pre_coreset_data)
+        elif isinstance(pre_coreset_data, tuple):
+            self.pre_coreset_data = as_supervised_data(pre_coreset_data)
+        else:
+            self.pre_coreset_data = pre_coreset_data
+
     def __check_init__(self):
         """Check that coreset has fewer 'nodes' than the 'pre_coreset_data'."""
         if len(self.nodes) > len(self.pre_coreset_data):
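
With this __init__, raw arrays are promoted to Data and (data, supervision) tuples to SupervisedData automatically. A usage sketch under those semantics (hypothetical values; constructor behaviour as shown in the diff above):

    import jax.numpy as jnp
    from coreax.coreset import Coreset

    points = jnp.arange(20.0).reshape(10, 2)  # pre-coreset dataset
    nodes = points[:3]                        # candidate coreset nodes

    # Both arguments are raw arrays, so both are converted via as_data.
    coreset = Coreset(nodes, points)

    # A (data, supervision) tuple would instead be converted via as_supervised_data.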
@@ -89,21 +105,23 @@ def __len__(self):
         return len(self.nodes)

     @property
-    def coreset(self) -> Data:
+    def coreset(self) -> _Data:
         """Materialised coreset."""
         return self.nodes

-    def solve_weights(self, solver: WeightsOptimiser, **solver_kwargs) -> Self:
+    def solve_weights(self, solver: WeightsOptimiser[_Data], **solver_kwargs) -> Self:
         """Return a copy of 'self' with weights solved by 'solver'."""
         weights = solver.solve(self.pre_coreset_data, self.coreset, **solver_kwargs)
         return eqx.tree_at(lambda x: x.nodes.weights, self, weights)

-    def compute_metric(self, metric: Metric, **metric_kwargs) -> Array:
+    def compute_metric(
+        self, metric: Metric[_Data], **metric_kwargs
+    ) -> Shaped[Array, ""]:
         """Return metric-distance between `self.pre_coreset_data` and `self.coreset`."""
         return metric.compute(self.pre_coreset_data, self.coreset, **metric_kwargs)


-class Coresubset(Coreset[_Data], Generic[_Data]):
+class Coresubset(Coreset[Data], Generic[_Data]):
     r"""
     Data structure for representing a coresubset.

@@ -131,26 +149,17 @@ class Coresubset(Coreset[_Data], Generic[_Data]):
     :param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
     """

-    # Incompatibility between Pylint and eqx.field. Pyright handles this correctly.
-    # pylint: disable=no-member
+    def __init__(self, nodes: Data, pre_coreset_data: _Data):
+        """Handle typing of ``nodes`` being a `Data` instance."""
+        super().__init__(nodes, pre_coreset_data)

     @property
-    def coreset(self) -> Data:
+    def coreset(self) -> _Data:
         """Materialise the coresubset from the indices and original data."""
-        coreset_data = self.pre_coreset_data.data[self.unweighted_indices]
-        if isinstance(self.pre_coreset_data, SupervisedData):
-            coreset_supervision = self.pre_coreset_data.supervision[
-                self.unweighted_indices
-            ]
-            return SupervisedData(
-                data=coreset_data,
-                supervision=coreset_supervision,
-                weights=self.nodes.weights,
-            )
-        return Data(data=coreset_data, weights=self.nodes.weights)
+        coreset_data = self.pre_coreset_data[self.unweighted_indices]
+        return eqx.tree_at(lambda x: x.weights, coreset_data, self.nodes.weights)

     @property
     def unweighted_indices(self) -> Shaped[Array, " n"]:
         """Unweighted Coresubset indices - attribute access helper."""
         return jnp.squeeze(self.nodes.data)
-
-    # pylint: enable=no-member
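
The rewritten coreset property relies on equinox's tree_at to replace a single leaf of a pytree out-of-place, which works uniformly for Data and SupervisedData. A standalone sketch of that pattern (toy module, not coreax types):

    import equinox as eqx
    import jax.numpy as jnp
    from jax import Array

    class Weighted(eqx.Module):
        data: Array
        weights: Array

    old = Weighted(data=jnp.ones(3), weights=jnp.zeros(3))
    # Returns a copy of `old` in which only `weights` has been replaced.
    new = eqx.tree_at(lambda m: m.weights, old, jnp.full(3, 0.5))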