Skip to content

Stack and concatenate #844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/doc/ndarray_for_numpy_users/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
Expand Down Expand Up @@ -642,6 +643,8 @@
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
//! [stack!]: ../../macro.stack.html
//! [stack()]: ../../fn.stack.html
//! [stack_new_axis!]: ../../macro.stack_new_axis.html
//! [stack_new_axis()]: ../../fn.stack_new_axis.html
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
Expand Down
4 changes: 2 additions & 2 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::iter::{
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
use crate::stacking::stack;
use crate::stacking::concatenate;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};

/// # Methods For All Array Types
Expand Down Expand Up @@ -840,7 +840,7 @@ where
dim.set_axis(axis, 0);
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
} else {
stack(axis, &subs).unwrap()
concatenate(axis, &subs).unwrap()
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane

pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
pub use crate::stacking::stack;

#[allow(deprecated)]
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::ShapeBuilder;
Expand Down
181 changes: 175 additions & 6 deletions src/stacking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

/// Stack arrays along the given axis.
/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
Expand All @@ -29,10 +29,11 @@ use crate::imp_prelude::*;
/// [3., 3.]]))
/// );
/// ```
pub fn stack<'a, A, D>(
axis: Axis,
arrays: &[ArrayView<'a, A, D>],
) -> Result<Array<A, D>, ShapeError>
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate` function instead"
)]
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
Expand Down Expand Up @@ -76,7 +77,103 @@ where
Ok(res)
}

/// Stack arrays along the given axis.
/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[allow(deprecated)]
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
{
stack(axis, arrays)
}

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack_new_axis(Axis(0), &[a.view(), a.view()])
/// == Ok(arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]]))
/// );
/// # }
/// ```
pub fn stack_new_axis<A, D>(
axis: Axis,
arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
A: Copy,
D: Dimension,
D::Larger: RemoveAxis,
{
if arrays.is_empty() {
return Err(from_kind(ErrorKind::Unsupported));
}
let common_dim = arrays[0].raw_dim();
// Avoid panic on `insert_axis` call, return an Err instead of it.
if axis.index() > common_dim.ndim() {
return Err(from_kind(ErrorKind::OutOfBounds));
}
let mut res_dim = common_dim.insert_axis(axis);

if arrays.iter().any(|a| a.raw_dim() != common_dim) {
return Err(from_kind(ErrorKind::IncompatibleShape));
}

res_dim.set_axis(axis, arrays.len());

// we can safely use uninitialized values here because they are Copy
// and we will only ever write to them
let size = res_dim.size();
let mut v = Vec::with_capacity(size);
unsafe {
v.set_len(size);
}
let mut res = Array::from_shape_vec(res_dim, v)?;

res.axis_iter_mut(axis)
.zip(arrays.into_iter())
.for_each(|(mut assign_view, array)| {
assign_view.assign(&array);
});

Ok(res)
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
Expand All @@ -103,9 +200,81 @@ where
/// );
/// # }
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate!` macro instead"
)]
#[macro_export]
macro_rules! stack {
($axis:expr, $( $array:expr ),+ ) => {
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
///
/// [1]: fn.concatenate.html
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, concatenate, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate![Axis(0), a, a]
/// == arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! concatenate {
($axis:expr, $( $array:expr ),+ ) => {
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}

/// Stack arrays along the new axis.
///
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
///
/// [1]: fn.stack_new_axis.html
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack_new_axis![Axis(0), a, a]
/// == arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! stack_new_axis {
($axis:expr, $( $array:expr ),+ ) => {
$crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}
45 changes: 43 additions & 2 deletions tests/stacking.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind};
#![allow(deprecated)]

use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};

#[test]
fn stacking() {
fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
Expand All @@ -23,4 +25,43 @@ fn stacking() {

let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);

let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));

let c = concatenate![Axis(0), a, b];
assert_eq!(
c,
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
);

let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
assert_eq!(d, aview1(&[2., 2., 9., 9.]));

let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}

#[test]
fn stacking() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}