Package org.flag4j.core
Class Shape
java.lang.Object
org.flag4j.core.Shape
- All Implemented Interfaces:
Serializable
An object to store the shape of a tensor. Shapes are immutable.
Field Summary
- private final int[] dims: An array containing the size of each dimension of this shape.
- private int[] strides: An array containing the strides of all dimensions within this shape.
- private BigInteger totalEntries: Total entries of this shape.
-
Constructor Summary
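- Shape(int... dims): Constructs a shape object from the specified dimension sizes.
- Shape(boolean computeStrides, int... dims): Constructs a shape object from the specified dimension sizes, computing its strides if computeStrides is true.
-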
Method Summary
- int[] createNewStrides(): Constructs strides for each dimension of this shape as if for a newly constructed tensor.
- int entriesIndex(int... indices): Computes the index in the 1D data array of a dense tensor from tensor indices with this shape.
- boolean equals(Object): Checks if an object is equal to this shape.
- int get(int i): Gets the size of the shape object in the specified dimension.
- int[] getDims(): Gets the shape of a tensor as an array.
- int[] getIndices(int index): Computes the nD tensor indices corresponding to an index in the internal 1D data array.
- void getNextIndices(int[] currentIndices, int i): Gets the next indices for a tensor with this shape.
- int getRank(): Gets the rank of a tensor with this shape.
- int[] getStrides(): Gets the strides of this shape as an array.
- int hashCode(): Generates the hash code for this shape object.
- void makeStridesIfNull(): If the strides are null, creates them.
- Shape swapAxes(int... axes): Permutes the axes of this shape.
- Shape swapAxes(int axis1, int axis2): Swaps two axes of this shape.
- String toString(): Converts this Shape object to a string format.
- BigInteger totalEntries(): Gets the total number of entries for a tensor with this shape.
-
Field Details
-
dims
private final int[] dims
An array containing the size of each dimension of this shape.
-
strides
private int[] strides
An array containing the strides of all dimensions within this shape.
-
totalEntries
private BigInteger totalEntries
Total entries of this shape. This is only computed on demand by totalEntries().
-
-
Constructor Details
-
Shape
public Shape(int... dims)
Constructs a shape object from the specified dimension sizes.
- Parameters:
dims - The size of each dimension of this shape. All entries must be non-negative.
- Throws:
IllegalArgumentException - If any dimension is negative.
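The validation this constructor documents can be sketched as a small helper. This is a minimal illustration, not flag4j's source: the class `ShapeDimsCheck` and its method `validatedCopy` are hypothetical, and the defensive copy is an assumption about how an immutable shape might guard its dimensions.

```java
// Hypothetical helper illustrating the documented constructor contract;
// not flag4j's actual implementation.
final class ShapeDimsCheck {

    private ShapeDimsCheck() {}

    /**
     * Returns a copy of dims after checking that every entry is non-negative.
     * The copy (an assumption of this sketch) keeps later writes to the
     * caller's array from changing the stored dimensions.
     */
    static int[] validatedCopy(int... dims) {
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] < 0) {
                throw new IllegalArgumentException(
                        "Dimension " + i + " is negative: " + dims[i]);
            }
        }
        return dims.clone();
    }
}
```

For example, `validatedCopy(2, 3, 4)` returns a fresh array `[2, 3, 4]`, while `validatedCopy(2, -1)` throws an `IllegalArgumentException`.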
-
Shape
public Shape(boolean computeStrides, int... dims)
Constructs a shape object from the specified dimension sizes.
- Parameters:
computeStrides - Flag indicating whether the strides of this shape should be computed.
dims - The size of each dimension of this shape. All entries must be non-negative.
- Throws:
IllegalArgumentException - If any dimension is negative.
-
-
Method Details
-
getRank
public int getRank()
Gets the rank of a tensor with this shape.
- Returns:
The rank of a tensor with this shape.
-
getDims
public int[] getDims()
Gets the shape of a tensor as an array.
- Returns:
The shape of a tensor as an integer array.
-
getStrides
public int[] getStrides()
Gets the strides of this shape as an array.
- Returns:
The strides of all dimensions of this shape as an integer array.
-
get
public int get(int i)
Gets the size of the shape object in the specified dimension.
- Parameters:
i - Dimension to get the size of.
- Returns:
The size of this shape object in the specified dimension.
-
createNewStrides
public int[] createNewStrides()
Constructs strides for each dimension of this shape as if for a newly constructed tensor, i.e. row-major (C-order) strides: the last stride is 1 and each earlier stride is the product of all dimensions that follow it.
- Returns:
The strides for all dimensions of a newly constructed tensor with this shape.
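A sketch of the row-major stride rule described above, assuming plain `int` arithmetic is sufficient (overflow is not checked). `RowMajorStrides.createStrides` is a hypothetical helper, not flag4j's implementation.

```java
// Hypothetical sketch of row-major (C-order) stride construction;
// not flag4j's actual source.
final class RowMajorStrides {

    private RowMajorStrides() {}

    /**
     * Computes row-major strides for the given dimensions: the last stride is 1 and
     * each earlier stride is the product of all dimensions to its right.
     */
    static int[] createStrides(int[] dims) {
        int[] strides = new int[dims.length];
        int stride = 1;
        // Walk from the last axis to the first, accumulating the product of trailing dims.
        for (int i = dims.length - 1; i >= 0; i--) {
            strides[i] = stride;
            stride *= dims[i];
        }
        return strides;
    }
}
```

For dimensions `{2, 3, 4}` this yields `{12, 4, 1}`; for `{3, 1, 4}` it yields `{4, 4, 1}`, so consecutive strides can repeat when a dimension has size 1.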
-
makeStridesIfNull
public void makeStridesIfNull()
If the strides are null, creates them. Otherwise, does nothing.
-
entriesIndex
public int entriesIndex(int... indices)
Computes the index in the 1D data array of a dense tensor from tensor indices with this shape.
- Parameters:
indices - Indices of a tensor with this shape.
- Returns:
The index of the element at the specified indices in the 1D data array of a dense tensor.
- Throws:
IllegalArgumentException - If the number of indices does not match the rank of this shape.
IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
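A sketch of the flattening performed by `entriesIndex`, assuming the row-major layout implied by `createNewStrides`. `FlatIndex.entriesIndex` takes the dimensions explicitly and is hypothetical; its loop uses Horner's scheme, which gives the same result as summing `indices[i] * strides[i]`.

```java
// Hypothetical sketch of row-major index flattening; not flag4j's actual source.
final class FlatIndex {

    private FlatIndex() {}

    /**
     * Maps nD tensor indices to an index in the 1D row-major data array.
     *
     * @throws IllegalArgumentException  if indices.length differs from dims.length
     * @throws IndexOutOfBoundsException if any index lies outside [0, dims[i] - 1]
     */
    static int entriesIndex(int[] dims, int... indices) {
        if (indices.length != dims.length) {
            throw new IllegalArgumentException(
                    "Expected " + dims.length + " indices but got " + indices.length);
        }
        int flat = 0;
        // Horner-style accumulation: equivalent to summing indices[i] * strides[i]
        // with row-major strides.
        for (int i = 0; i < dims.length; i++) {
            if (indices[i] < 0 || indices[i] >= dims[i]) {
                throw new IndexOutOfBoundsException(
                        "Index " + indices[i] + " out of bounds for axis " + i + " of size " + dims[i]);
            }
            flat = flat * dims[i] + indices[i];
        }
        return flat;
    }
}
```

With dimensions `{2, 3, 4}` (strides `{12, 4, 1}`), indices `(1, 2, 3)` map to `1*12 + 2*4 + 3*1 = 23`.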
-
getIndices
public int[] getIndices(int index)
Computes the nD tensor indices corresponding to an index in the internal 1D data array.
- Parameters:
index - Index in the internal 1D data array.
- Returns:
The multidimensional indices corresponding to the 1D data array index. This is an array of integers whose length equals the rank of this shape.
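The inverse mapping can be sketched by repeatedly taking remainders and quotients, starting from the last axis. This assumes row-major order, strictly positive dimensions, and `0 <= index < total entries`; the sketch does no bounds checking, and `UnflatIndex.getIndices` is a hypothetical helper.

```java
// Hypothetical sketch of the inverse of row-major flattening; not flag4j's actual source.
final class UnflatIndex {

    private UnflatIndex() {}

    /** Converts a 1D row-major data index into nD tensor indices for the given dims. */
    static int[] getIndices(int[] dims, int index) {
        int[] indices = new int[dims.length];
        // Peel off the fastest-varying (last) axis first.
        for (int i = dims.length - 1; i >= 0; i--) {
            indices[i] = index % dims[i];
            index /= dims[i];
        }
        return indices;
    }
}
```

For dimensions `{2, 3, 4}`, index 23 maps back to `{1, 2, 3}`.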
-
swapAxes
public Shape swapAxes(int axis1, int axis2)
Swaps two axes of this shape. New strides are constructed for the swapped shape.
- Parameters:
axis1 - First axis to swap.
axis2 - Second axis to swap.
- Returns:
A copy of this shape with the specified axes swapped.
- Throws:
ArrayIndexOutOfBoundsException - If either axis is not within [0, rank-1].
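A sketch of swapping two axes on a copy of the dimensions, leaving the original untouched as an immutable shape requires. `AxisSwap.swapAxes` is hypothetical; computing strides for the result (for instance with the row-major rule of `createNewStrides`) is left out.

```java
// Hypothetical sketch; not flag4j's actual source.
final class AxisSwap {

    private AxisSwap() {}

    /**
     * Returns a copy of dims with axes axis1 and axis2 exchanged; the input array is
     * not modified. An axis outside [0, dims.length - 1] causes the array access
     * itself to throw ArrayIndexOutOfBoundsException.
     */
    static int[] swapAxes(int[] dims, int axis1, int axis2) {
        int[] swapped = dims.clone();
        swapped[axis1] = dims[axis2];
        swapped[axis2] = dims[axis1];
        return swapped;
    }
}
```

Swapping axes 0 and 2 of `{2, 3, 4}` gives `{4, 3, 2}`; passing axis 3 for a rank-3 shape raises `ArrayIndexOutOfBoundsException`.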
-
swapAxes
public Shape swapAxes(int... axes)
Permutes the axes of this shape.
- Parameters:
axes - New axes permutation for the shape. This must be a permutation of {0, 1, 2, ..., N-1}, where N is the rank of this shape.
- Returns:
This shape.
- Throws:
ArrayIndexOutOfBoundsException - If axes is not a permutation of {0, 1, 2, ..., N-1}.
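A sketch of permuting axes with a permutation check. It assumes 0-based axes and the convention `result[i] = dims[axes[i]]` (the same convention NumPy's `transpose` uses); `AxisPermute.permute` is hypothetical and returns a new array rather than modifying any existing shape.

```java
// Hypothetical sketch; not flag4j's actual source. Axes are assumed to be 0-based.
final class AxisPermute {

    private AxisPermute() {}

    /**
     * Returns new dims where result[i] = dims[axes[i]].
     *
     * @throws ArrayIndexOutOfBoundsException if axes is not a permutation of {0, 1, ..., N-1}
     */
    static int[] permute(int[] dims, int... axes) {
        if (axes.length != dims.length) {
            throw new ArrayIndexOutOfBoundsException(
                    "Expected " + dims.length + " axes but got " + axes.length);
        }
        boolean[] seen = new boolean[dims.length];
        int[] permuted = new int[dims.length];
        for (int i = 0; i < axes.length; i++) {
            int axis = axes[i];
            // Reject out-of-range or repeated axes: both mean axes is not a permutation.
            if (axis < 0 || axis >= dims.length || seen[axis]) {
                throw new ArrayIndexOutOfBoundsException(
                        "axes is not a permutation of 0.." + (dims.length - 1));
            }
            seen[axis] = true;
            permuted[i] = dims[axis];
        }
        return permuted;
    }
}
```

Permuting `{2, 3, 4}` with axes `(2, 0, 1)` gives `{4, 2, 3}`; `(0, 0, 1)` is rejected because axis 0 repeats.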
-
totalEntries
public BigInteger totalEntries()
Gets the total number of entries for a tensor with this shape.
- Returns:
The total number of entries for a tensor with this shape.
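Because the field backing this value is a `BigInteger`, the count can presumably exceed `Integer.MAX_VALUE` without overflowing. A sketch of computing it on demand and caching it, with the hypothetical class `EntryCount` standing in for a shape:

```java
import java.math.BigInteger;

// Hypothetical sketch of an on-demand, cached entry count; not flag4j's actual source.
final class EntryCount {

    private final int[] dims;
    private BigInteger totalEntries; // null until first requested

    EntryCount(int... dims) {
        this.dims = dims.clone();
    }

    /** Product of all dimensions, computed on the first call and cached afterwards. */
    BigInteger totalEntries() {
        if (totalEntries == null) {
            BigInteger product = BigInteger.ONE;
            for (int d : dims) {
                product = product.multiply(BigInteger.valueOf(d));
            }
            totalEntries = product;
        }
        return totalEntries;
    }
}
```

`new EntryCount(100000, 100000).totalEntries()` returns 10000000000, a product that would overflow an `int`.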
-
equals
Checks if an object is equal to this shape.
-
getNextIndices
public void getNextIndices(int[] currentIndices, int i)
Gets the next indices for a tensor with this shape.
- Parameters:
currentIndices - Current indices. This array is modified in place.
i - Index in the 1D data array.
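One plausible reading of `getNextIndices` is an odometer-style increment in row-major order; the sketch assumes that reading. It does not model the `i` parameter, whose exact role the description above does not spell out, and `IndexOdometer.next` is a hypothetical helper.

```java
// Hypothetical sketch of advancing nD indices in row-major order (an "odometer"
// increment); not flag4j's actual source. The i parameter is not modeled.
final class IndexOdometer {

    private IndexOdometer() {}

    /**
     * Advances currentIndices in place to the next position in row-major order.
     * After the last position, the indices wrap around to all zeros.
     */
    static void next(int[] dims, int[] currentIndices) {
        // Increment the last (fastest-varying) axis and carry into earlier axes.
        for (int axis = dims.length - 1; axis >= 0; axis--) {
            currentIndices[axis]++;
            if (currentIndices[axis] < dims[axis]) {
                return; // no carry needed
            }
            currentIndices[axis] = 0; // this axis overflowed; carry to the previous one
        }
    }
}
```

For dimensions `{2, 3}`, the indices `{0, 2}` advance to `{1, 0}`. After the last position `{1, 2}` this sketch wraps to `{0, 0}`; the wrap-around is a choice made here, not documented behavior.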
-
hashCode
public int hashCode()
Generates the hash code for this shape object. It is computed by passing the dims array of this shape object to Arrays.hashCode(int[]).
-
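A sketch of an `equals`/`hashCode` pair consistent with the hash described above. The hash matches the documented use of `Arrays.hashCode(int[])`; deciding equality by the dimensions alone is an assumption, and `DimsKey` is a hypothetical class.

```java
import java.util.Arrays;

// Hypothetical sketch of an equals/hashCode pair based only on the dimensions;
// not flag4j's actual source.
final class DimsKey {

    private final int[] dims;

    DimsKey(int... dims) {
        this.dims = dims.clone();
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) return true;
        if (other == null || getClass() != other.getClass()) return false;
        return Arrays.equals(dims, ((DimsKey) other).dims);
    }

    @Override
    public int hashCode() {
        // Consistent with equals: equal dims arrays produce equal hash codes.
        return Arrays.hashCode(dims);
    }
}
```

Because equal dimension arrays give equal hash codes, such objects behave correctly as keys in a `HashMap` or members of a `HashSet`.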
toString
public String toString()
Converts this Shape object to a string format.