Skip to content

Commit

Permalink
add squeeze and unsqueeze (#235)
Browse files Browse the repository at this point in the history
* add addInnerDim and addOuterDim

* rework with numir like API
  • Loading branch information
9il authored Apr 19, 2020
1 parent 60c330e commit cdc31b8
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 4 deletions.
1 change: 0 additions & 1 deletion source/mir/algorithm/iteration.d
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ import std.traits;

@optmath:


/+
Bitslice representation for accelerated bitwise algorithm.
1-dimensional contiguousitslice can be split into three chunks: head bits, body chunks, and tail bits.
Expand Down
2 changes: 2 additions & 0 deletions source/mir/ndslice/package.d
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ $(TR $(TDNW $(SUBMODULE topology) $(BR)
$(SUBREF topology, ReshapeError)
$(SUBREF topology, retro)
$(SUBREF topology, slide)
$(SUBREF topology, squeeze)
$(SUBREF topology, stairs)
$(SUBREF topology, stride)
$(SUBREF topology, subSlices)
$(SUBREF topology, triplets)
$(SUBREF topology, universal)
$(SUBREF topology, unsqueeze)
$(SUBREF topology, unzip)
$(SUBREF topology, windows)
$(SUBREF topology, zip)
Expand Down
173 changes: 170 additions & 3 deletions source/mir/ndslice/topology.d
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ $(TR $(TH Function Name) $(TH Description))
$(T2 blocks, n-dimensional slice composed of n-dimensional non-overlapping blocks. If the slice has two dimensions, it is a block matrix.)
$(T2 diagonal, 1-dimensional slice composed of diagonal elements)
$(T2 reshape, New slice with changed dimensions for the same data)
$(T2 reshape, New slice view with changed dimensions)
$(T2 squeeze, New slice view of an n-dimensional slice with dimension removed)
$(T2 unsqueeze, New slice view of an n-dimensional slice with a dimension added)
$(T2 windows, n-dimensional slice of n-dimensional overlapping windows. If the slice has two dimensions, it is a sliding window.)
)
Expand Down Expand Up @@ -96,7 +98,7 @@ $(SUBREF slice, Slice.shape), and $(SUBREF slice, Slice.elementCount).
License: $(HTTP boost.org/LICENSE_1_0.txt, Boost License 1.0).
Copyright: Copyright © 2016-, Ilya Yaroshenko
Authors: Ilya Yaroshenko
Authors: Ilya Yaroshenko, Shigeki Karita (original numir code)
Sponsors: Part of this work has been sponsored by $(LINK2 http://symmetryinvestments.com, Symmetry Investments) and Kaleidic Associates.
Expand Down Expand Up @@ -3918,7 +3920,7 @@ template byDim(Dimensions...)
n-dimensional slice ipacked to allow iteration by dimension
+/
@optmath auto byDim(Iterator, size_t N, SliceKind kind)
(Slice!(Iterator, N, kind) slice)
(Slice!(Iterator, N, kind) slice)
{
import mir.ndslice.topology : ipack;
import mir.ndslice.internal : DimensionsCountCTError;
Expand Down Expand Up @@ -4240,6 +4242,171 @@ version(mir_test) unittest
assert(x == slice);
}

/++
Constructs a new view of an n-dimensional slice with dimension `axis` removed.
Throws:
`AssertError` if the length of the corresponding dimension doesn' equal 1.
Params:
axis = dimension to remove, if it is single-dimensional
slice = n-dimensional slice
Returns:
new view of a slice with dimension removed
See_also: $(LREF unsqueeze), $(LREF iota).
+/
Slice!(Iterator, N - 1, kind != Canonical ? kind : axis == 0 ? Universal : N == 2 ? Contiguous : kind)
squeeze(sizediff_t axis = 0, Iterator, size_t N, SliceKind kind)
(Slice!(Iterator, N, kind) slice)
if (-sizediff_t(N) <= axis && axis < sizediff_t(N) && N > 1)
in {
assert(slice._lengths[axis < 0 ? N + axis : axis] == 1);
}
do {
import mir.utility: swap;
enum sizediff_t a = axis < 0 ? N + axis : axis;
typeof(return) ret;
foreach (i; 0 .. a)
ret._lengths[i] = slice._lengths[i];
foreach (i; a + 1 .. N)
ret._lengths[i - 1] = slice._lengths[i];
static if (kind == Universal)
{
foreach (i; 0 .. a)
ret._strides[i] = slice._strides[i];
foreach (i; a + 1.. N)
ret._strides[i - 1] = slice._strides[i];
}
else
static if (kind == Canonical)
{
static if (a == 0)
{
foreach (i; 0 .. N - 1)
ret._strides[i] = slice._strides[i];
}
else
{
foreach (i; 0 .. a - 1)
ret._strides[i] = slice._strides[i];
foreach (i; a .. N - 1)
ret._strides[i - 1] = slice._strides[i];
}
}
swap(ret._iterator, slice._iterator);
return ret;
}

///
unittest
{
import mir.ndslice.topology : iota;
import mir.ndslice.allocation : slice;

// [[0, 1, 2]] -> [0, 1, 2]
assert([1, 3].iota.squeeze == [3].iota);
// [[0], [1], [2]] -> [0, 1, 2]
assert([3, 1].iota.squeeze!1 == [3].iota);
assert([3, 1].iota.squeeze!(-1) == [3].iota);

assert([1, 3].iota.canonical.squeeze == [3].iota);
assert([3, 1].iota.canonical.squeeze!1 == [3].iota);
assert([3, 1].iota.canonical.squeeze!(-1) == [3].iota);

assert([1, 3].iota.universal.squeeze == [3].iota);
assert([3, 1].iota.universal.squeeze!1 == [3].iota);
assert([3, 1].iota.universal.squeeze!(-1) == [3].iota);

assert([1, 3, 4].iota.squeeze == [3, 4].iota);
assert([3, 1, 4].iota.squeeze!1 == [3, 4].iota);
assert([3, 4, 1].iota.squeeze!(-1) == [3, 4].iota);

assert([1, 3, 4].iota.canonical.squeeze == [3, 4].iota);
assert([3, 1, 4].iota.canonical.squeeze!1 == [3, 4].iota);
assert([3, 4, 1].iota.canonical.squeeze!(-1) == [3, 4].iota);

assert([1, 3, 4].iota.universal.squeeze == [3, 4].iota);
assert([3, 1, 4].iota.universal.squeeze!1 == [3, 4].iota);
assert([3, 4, 1].iota.universal.squeeze!(-1) == [3, 4].iota);
}

/++
Constructs a view of an n-dimensional slice with a dimension added at `axis`. Used
to unsqueeze a squeezed slice.
Params:
slice = n-dimensional slice
axis = dimension to be unsqueezed (add new dimension), default values is 0, the first dimension
Returns:
unsqueezed n+1-dimensional slice of the same slice kind
See_also: $(LREF squeeze), $(LREF iota).
+/
Slice!(Iterator, N + 1, kind) unsqueeze(Iterator, size_t N, SliceKind kind)
(Slice!(Iterator, N, kind) slice, sizediff_t axis = 0)
in {
assert(-sizediff_t(N + 1) <= axis && axis <= sizediff_t(N));
}
do {
import mir.utility: swap;
typeof(return) ret;
if (axis < 0)
{
axis += N + 1;
}
foreach (i; 0 .. axis)
ret._lengths[i] = slice._lengths[i];
ret._lengths[axis] = 1;
foreach (i; axis .. N)
ret._lengths[i + 1] = slice._lengths[i];
static if (kind == Universal)
{
foreach (i; 0 .. axis)
ret._strides[i] = slice._strides[i];
foreach (i; axis .. N)
ret._strides[i + 1] = slice._strides[i];
}
else
static if (kind == Canonical)
{
if (axis == 0)
{
ret._strides[0] = 1;
foreach (i; 1 .. N)
ret._strides[i] = slice._strides[i - 1];
}
else
{
foreach (i; 1 .. axis)
ret._strides[i - 1] = slice._strides[i - 1];
foreach (i; axis .. N)
ret._strides[i + 0] = slice._strides[i - 1];
}
}
swap(ret._iterator, slice._iterator);
return ret;
}

///
version (mir_test)
@safe pure nothrow @nogc
unittest
{
// [0, 1, 2] -> [[0, 1, 2]]
assert([3].iota.unsqueeze == [1, 3].iota);

assert([3].iota.universal.unsqueeze == [1, 3].iota);
assert([3, 4].iota.unsqueeze == [1, 3, 4].iota);
assert([3, 4].iota.canonical.unsqueeze == [1, 3, 4].iota);
assert([3, 4].iota.universal.unsqueeze == [1, 3, 4].iota);

// [0, 1, 2] -> [[0], [1], [2]]
assert([3].iota.unsqueeze(-1) == [3, 1].iota);

assert([3].iota.universal.unsqueeze(-1) == [3, 1].iota);
assert([3, 4].iota.unsqueeze(-1) == [3, 4, 1].iota);
assert([3, 4].iota.canonical.unsqueeze(-1) == [3, 4, 1].iota);
assert([3, 4].iota.universal.unsqueeze(-1) == [3, 4, 1].iota);
}

/++
Field (element's member) projection.
Expand Down

0 comments on commit cdc31b8

Please sign in to comment.