Class Shape
- All Implemented Interfaces:
Serializable
A shape is defined by an array of dimensions, where each dimension specifies the size of the tensor along a particular axis.
Strides can also be computed for the shape. They specify the number of entries to step over in each dimension of the shape when traversing an array with the given shape. Strides are always row-major contiguous, which allows efficient array traversal and mapping of nD indices to 1D contiguous indices.
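Here is a minimal sketch of how row-major contiguous strides can be computed from a dimensions array. `StrideSketch` and `computeStrides` are hypothetical names used only for illustration and are not part of the Shape API. The sketch assumes the usual C-order convention, where the last axis has stride 1 and each earlier stride is the product of all later dimensions.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating row-major (C-order) contiguous strides. */
final class StrideSketch {

    /** The last axis has stride 1; each earlier axis steps over the product of all later dimensions. */
    static int[] computeStrides(int[] dims) {
        int[] strides = new int[dims.length];
        int stride = 1;
        for (int i = dims.length - 1; i >= 0; i--) {
            strides[i] = stride;
            stride *= dims[i];
        }
        return strides;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(computeStrides(new int[] {3, 4, 5}))); // prints [20, 5, 1]
    }
}
```

For a 3x4x5 shape this gives strides {20, 5, 1}. Stepping once along axis 0 skips a whole 4x5 block of 20 entries.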
This class also supports converting between multidimensional and flat indices, computing the shape's rank (i.e., its number of dimensions), computing the total number of entries in an array with the given shape, and manipulating dimensions through swaps or permutations.
The Shape class is immutable with respect to its dimensions, ensuring thread safety and consistency. Strides are computed lazily, only when needed, to minimize overhead.
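One way to combine immutable dimensions with lazily computed strides is sketched below. `LazyShapeSketch` is a hypothetical class, not the actual Shape implementation. It assumes defensive copies of the dims array and a `volatile` cache field, so that a fully computed strides array is safely published to other threads.

```java
/**
 * Hypothetical immutable shape: dims are copied and never modified, while
 * strides are computed on first request and then cached.
 */
final class LazyShapeSketch {
    private final int[] dims;
    // volatile so that a fully built strides array is safely published to other threads.
    private volatile int[] strides;

    LazyShapeSketch(int... dims) {
        this.dims = dims.clone(); // defensive copy: later changes to the caller's array have no effect
    }

    int[] getDims() {
        return dims.clone(); // callers cannot mutate the internal array
    }

    int[] getStrides() {
        int[] s = strides;
        if (s == null) {
            // Compute row-major strides into a local array before publishing it.
            s = new int[dims.length];
            int stride = 1;
            for (int i = dims.length - 1; i >= 0; i--) {
                s[i] = stride;
                stride *= dims[i];
            }
            strides = s;
        }
        return s.clone(); // copy so the cached array stays unmodified
    }

    boolean stridesComputed() {
        return strides != null;
    }
}
```

If two threads call getStrides() concurrently, both may compute the array. Because the result is deterministic, either write is correct, so no lock is needed.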
This class is a fundamental building block for tensor ops, particularly in contexts where multidimensional indexing and dimension manipulations are required.
Example usage:
```java
Shape shape = new Shape(); // Creates a shape for a scalar value.
shape = new Shape(3, 4, 5); // Creates a shape for a 3x4x5 tensor.
int rank = shape.getRank(); // Gets the rank (number of dimensions).
int[] strides = shape.getStrides(); // Retrieves the strides for this shape.
int flatIndex = shape.getFlatIndex(2, 1, 4); // Converts multidimensional indices to a flat index.
int[] multiDimIndex = shape.getNdIndices(flatIndex); // Converts a flat index back to multidimensional indices.
```
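Constructor Summary
| Constructor | Description |
|---|---|
| Shape(int... dims) | Constructs a shape object from the specified dimensions. |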
Method Summary
| Modifier and Type | Method | Description |
|---|---|---|
| boolean | equals(Object) | Checks if an object is equal to this shape. |
| Shape | flatten() | Flattens this shape to a rank-1 shape with dimension equal to the product of all of this shape's dimensions. |
| int | get(int i) | Gets the size of this shape in the specified dimension. |
| int[] | getDims() | Gets the shape of a tensor as an array of dimensions. |
| int | getFlatIndex(int... indices) | Computes the index into the 1D data array of a dense tensor with this shape from nD indices. |
| int[] | getNdIndices(int index) | Efficiently computes the nD tensor indices corresponding to an index into the internal 1D data array. |
| int | getRank() | Gets the rank of a tensor with this shape. |
| int[] | getStrides() | Gets the strides of this shape as an array. |
| int | hashCode() | Generates the hash code for this shape object. |
| Shape | permuteAxes(int... axes) | Permutes the axes of this shape. |
| Shape | slice(int startIdx) | Returns a slice of this shape from the specified index to the end of this shape's dimensions. |
| Shape | slice(int startIdx, int stopIdx) | Returns a slice of this shape from the specified start index to the stop index of this shape's dimensions. |
| Shape | swapAxes(int axis1, int axis2) | Swaps two axes of this shape. |
| String | toString() | Converts this Shape object to a string format. |
| BigInteger | totalEntries() | Gets the total number of entries for a tensor with this shape. |
| int | totalEntriesIntValueExact() | Gets the total number of entries for a tensor with this shape as a primitive int. |
| int | unsafeGetFlatIndex(int... indices) | Computes the index into the 1D data array of a dense tensor with this shape from nD indices, without bounds checking. |
| Shape | unsafePermuteAxes(int... axes) | Permutes the axes of this shape without validating axes. |

Constructor Details

Shape
public Shape(int... dims)
Constructs a shape object from the specified dimensions.
- Parameters:
  dims - The size of each dimension of this shape object. All dimensions must be non-negative.
- Throws:
  IllegalArgumentException - If any dimension is negative.
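The validation rule above amounts to a single pass over the dimensions. In this hypothetical `ValidatedShapeSketch`, which is not the library's code, zero-sized dimensions are accepted because only negative values are rejected. The exception message is illustrative.

```java
/** Hypothetical sketch of the documented constructor validation. */
final class ValidatedShapeSketch {
    private final int[] dims;

    ValidatedShapeSketch(int... dims) {
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] < 0) {
                // Zero is allowed; only negative sizes are rejected.
                throw new IllegalArgumentException("Dimension " + i + " is negative: " + dims[i]);
            }
        }
        this.dims = dims.clone();
    }

    int getRank() {
        return dims.length;
    }
}
```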

Method Details

getRank
public int getRank()
Gets the rank of a tensor with this shape.
- Returns:
  The rank of a tensor with this shape.

getDims
public int[] getDims()
Gets the shape of a tensor as an array of dimensions.
- Returns:
  The shape of a tensor as an integer array.

getStrides
public int[] getStrides()
Gets the strides of this shape as an array. Strides are the step sizes, in number of entries, needed to move from one element to the next along each axis of the tensor.
- Returns:
  The strides of this shape as an integer array.

get
public int get(int i)
Gets the size of this shape in the specified dimension.
- Parameters:
  i - The dimension to get the size of.
- Returns:
  The size of this shape in the specified dimension.

slice
public Shape slice(int startIdx)
Returns a slice of this shape starting from the specified index to the end of this shape's dimensions.
- Parameters:
  startIdx - The starting index for slicing (inclusive).
- Returns:
  A new Shape object containing the dimensions from startIdx to the last dimension.
- Throws:
  IndexOutOfBoundsException - If startIdx is out of bounds for the rank of this shape.

slice
public Shape slice(int startIdx, int stopIdx)
Returns a slice of this shape from the specified start index to the stop index of this shape's dimensions.
- Parameters:
  startIdx - The starting index for slicing (inclusive).
  stopIdx - The stopping index for slicing (exclusive).
- Returns:
  A new Shape object containing the dimensions from startIdx to stopIdx.
- Throws:
  IndexOutOfBoundsException - If startIdx or stopIdx is out of bounds.
  IllegalArgumentException - If startIdx > stopIdx.
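Both slice overloads reduce to copying a contiguous range of the dimensions array, sketched below with `Arrays.copyOfRange`. The class `SliceSketch` is hypothetical. It assumes that indices in the closed range [0, rank] are in bounds, so an empty slice such as slice(rank) is allowed. The documentation above does not state this explicitly.

```java
import java.util.Arrays;

/** Hypothetical sketch of slicing a dimensions array following the documented rules. */
final class SliceSketch {

    /** Returns dims[startIdx, rank). */
    static int[] slice(int[] dims, int startIdx) {
        return slice(dims, startIdx, dims.length);
    }

    /** Returns dims[startIdx, stopIdx), checking bounds before order. */
    static int[] slice(int[] dims, int startIdx, int stopIdx) {
        if (startIdx < 0 || startIdx > dims.length || stopIdx < 0 || stopIdx > dims.length) {
            throw new IndexOutOfBoundsException(
                    "Slice [" + startIdx + ", " + stopIdx + ") is out of bounds for rank " + dims.length);
        }
        if (startIdx > stopIdx) {
            throw new IllegalArgumentException("startIdx > stopIdx: " + startIdx + " > " + stopIdx);
        }
        return Arrays.copyOfRange(dims, startIdx, stopIdx);
    }
}
```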

flatten
public Shape flatten()
Flattens this shape to a rank-1 shape with dimension equal to the product of all of this shape's dimensions.
- Returns:
  A rank-1 shape with dimension equal to the product of all of this shape's dimensions.
- Throws:
  ArithmeticException - If the product of this shape's dimensions is too large to be stored in a 32-bit integer.
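The overflow check can be expressed with `Math.multiplyExact`, which throws an `ArithmeticException` when an int product overflows. `FlattenSketch` is a hypothetical helper that returns the single dimension of the flattened shape. It returns early when any dimension is zero, so a shape whose true product is 0 is not rejected because of an intermediate overflow. This is an assumption about the intended behavior.

```java
import java.util.Arrays;

/** Hypothetical sketch of flattening a shape with int-overflow detection. */
final class FlattenSketch {

    /** Returns the dims of the rank-1 flattened shape: the product of all dimensions. */
    static int[] flatten(int[] dims) {
        for (int d : dims) {
            if (d == 0) {
                return new int[] {0}; // the true product is 0, so it always fits
            }
        }
        int product = 1;
        for (int d : dims) {
            // Math.multiplyExact throws ArithmeticException if the int product overflows.
            product = Math.multiplyExact(product, d);
        }
        return new int[] {product};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(flatten(new int[] {3, 4, 5}))); // prints [60]
    }
}
```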

getFlatIndex
public int getFlatIndex(int... indices)
Computes the index into the 1D data array of a dense tensor with this shape from the given nD indices.
- Parameters:
  indices - Indices of a tensor with this shape.
- Returns:
  The index of the element at the specified indices in the 1D data array of a dense tensor.
- Throws:
  IllegalArgumentException - If the number of indices does not match the rank of this shape.
  IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
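For a row-major shape, the flat index equals the sum of each index times the corresponding stride. The hypothetical `FlatIndexSketch` below computes the same value with Horner-style accumulation, so it needs only the dimensions, and it performs the two documented checks. It is an illustration, not the library's implementation.

```java
/** Hypothetical sketch of checked nD-to-flat index conversion for a row-major shape. */
final class FlatIndexSketch {

    static int getFlatIndex(int[] dims, int... indices) {
        if (indices.length != dims.length) {
            throw new IllegalArgumentException(
                    "Expected " + dims.length + " indices but got " + indices.length);
        }
        int flat = 0;
        for (int i = 0; i < dims.length; i++) {
            if (indices[i] < 0 || indices[i] >= dims[i]) {
                throw new IndexOutOfBoundsException(
                        "Index " + indices[i] + " is out of bounds for axis " + i + " of size " + dims[i]);
            }
            // Horner's scheme: equivalent to summing indices[i] * strides[i] for row-major strides.
            flat = flat * dims[i] + indices[i];
        }
        return flat;
    }
}
```

For a 3x4x5 shape with strides {20, 5, 1}, the indices (2, 1, 4) give 2*20 + 1*5 + 4*1 = 49.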

unsafeGetFlatIndex
public int unsafeGetFlatIndex(int... indices)
Computes the index into the 1D data array of a dense tensor with this shape from the given nD indices.
Warning: Unlike getFlatIndex(int...), this method does not perform bounds checking on indices. If indices are not valid, this can cause an exception to be thrown or, possibly, no exception but an incorrect result.
- Parameters:
  indices - Indices of a tensor with this shape.
- Returns:
  The index of the element at the specified indices in the 1D data array of a dense tensor.
- Throws:
  IllegalArgumentException - If the number of indices does not match the rank of this shape.
  IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
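The warning is easiest to see with an invalid index that does not trigger an exception. In this hypothetical sketch, which is not the library's code, the flat index is computed directly from precomputed strides without any checks. For a 3x4 shape, the out-of-range pair (0, 5) silently maps to the same entry as (1, 1).

```java
/** Hypothetical sketch of an unchecked nD-to-flat index computation. */
final class UnsafeFlatIndexSketch {

    /** Sums indices[i] * strides[i] without any rank or bounds checks. */
    static int unsafeGetFlatIndex(int[] strides, int... indices) {
        int flat = 0;
        for (int i = 0; i < indices.length; i++) {
            flat += indices[i] * strides[i]; // too many indices throws; out-of-range values do not
        }
        return flat;
    }

    public static void main(String[] args) {
        int[] strides = {4, 1}; // row-major strides of a 3x4 shape
        System.out.println(unsafeGetFlatIndex(strides, 1, 1)); // prints 5
        System.out.println(unsafeGetFlatIndex(strides, 0, 5)); // prints 5: (0, 5) is invalid but aliases (1, 1)
    }
}
```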

getNdIndices
public int[] getNdIndices(int index)
Efficiently computes the nD tensor indices corresponding to an index into the internal 1D data array.
- Parameters:
  index - Index into the internal 1D data array.
- Returns:
  The multidimensional indices corresponding to the 1D data array index. This is an array of integers with length equal to the rank of this shape.
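The inverse mapping repeatedly takes the remainder and quotient by each dimension, starting from the last (fastest-varying) axis. `NdIndicesSketch` is a hypothetical helper. It assumes a row-major shape with no zero-sized dimensions and does not validate `index`.

```java
import java.util.Arrays;

/** Hypothetical sketch of flat-to-nD index conversion for a row-major shape. */
final class NdIndicesSketch {

    static int[] getNdIndices(int[] dims, int index) {
        int[] nd = new int[dims.length];
        int remaining = index;
        // Peel off one index per axis, from the last (fastest-varying) axis to the first.
        for (int i = dims.length - 1; i >= 0; i--) {
            nd[i] = remaining % dims[i];
            remaining /= dims[i];
        }
        return nd;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(getNdIndices(new int[] {3, 4, 5}, 49))); // prints [2, 1, 4]
    }
}
```

For a 3x4x5 shape, index 49 maps back to [2, 1, 4], since 49 = 2*20 + 1*5 + 4*1.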

swapAxes
public Shape swapAxes(int axis1, int axis2)
Swaps two axes of this shape. If this shape has already had its strides computed, new strides are also computed for the resulting shape.
- Parameters:
  axis1 - The first axis to swap.
  axis2 - The second axis to swap.
- Returns:
  A copy of this shape with the specified axes swapped.
- Throws:
  ArrayIndexOutOfBoundsException - If either axis is not within [0, rank - 1].
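A minimal sketch of swapping two axes on a copy of the dimensions array follows. `SwapAxesSketch` is hypothetical. It relies on Java's built-in array bounds checks to raise `ArrayIndexOutOfBoundsException` for an axis outside [0, rank - 1], and it does not model the stride recomputation.

```java
import java.util.Arrays;

/** Hypothetical sketch: swap two axes on a copy of the dimensions. */
final class SwapAxesSketch {

    static int[] swapAxes(int[] dims, int axis1, int axis2) {
        int[] swapped = dims.clone(); // the input array is left untouched
        // An axis outside [0, rank - 1] throws ArrayIndexOutOfBoundsException here.
        int tmp = swapped[axis1];
        swapped[axis1] = swapped[axis2];
        swapped[axis2] = tmp;
        return swapped;
    }

    public static void main(String[] args) {
        int[] dims = {3, 4, 5};
        System.out.println(Arrays.toString(swapAxes(dims, 0, 2))); // prints [5, 4, 3]
        System.out.println(Arrays.toString(dims));                 // prints [3, 4, 5]
    }
}
```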

permuteAxes
public Shape permuteAxes(int... axes)
Permutes the axes of this shape.
- Parameters:
  axes - The new axes permutation for the shape. This must be a permutation of {0, 1, 2, ..., N - 1}, where N is the rank of this shape.
- Returns:
  The shape with its axes permuted.
- Throws:
  ArrayIndexOutOfBoundsException - If axes is not a permutation of {0, 1, 2, ..., N - 1}.
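Below is a sketch of a checked axis permutation. `PermuteAxesSketch` is hypothetical. It assumes the common convention, used for example by NumPy's transpose, that axis i of the result takes the size of axis axes[i] of the original. It throws `ArrayIndexOutOfBoundsException` for invalid input to mirror the documented exception.

```java
import java.util.Arrays;

/** Hypothetical sketch of a validated axis permutation. */
final class PermuteAxesSketch {

    static int[] permuteAxes(int[] dims, int... axes) {
        if (axes.length != dims.length) {
            throw new ArrayIndexOutOfBoundsException(
                    "Expected " + dims.length + " axes but got " + axes.length);
        }
        boolean[] seen = new boolean[dims.length];
        int[] permuted = new int[dims.length];
        for (int i = 0; i < axes.length; i++) {
            int axis = axes[i];
            // Reject axes that are out of range or repeated.
            if (axis < 0 || axis >= dims.length || seen[axis]) {
                throw new ArrayIndexOutOfBoundsException(
                        "axes is not a permutation of {0, ..., " + (dims.length - 1) + "}");
            }
            seen[axis] = true;
            permuted[i] = dims[axis]; // result axis i takes the size of original axis axes[i]
        }
        return permuted;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(permuteAxes(new int[] {3, 4, 5}, 2, 0, 1))); // prints [5, 3, 4]
    }
}
```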

unsafePermuteAxes
public Shape unsafePermuteAxes(int... axes)
Permutes the axes of this shape.
Warning: Unlike permuteAxes(int...), this method does not perform bounds checking on axes or ensure that axes is a permutation of {0, 1, 2, ..., N - 1}. This may result in unexpected behavior if axes is malformed.
- Parameters:
  axes - The new axes permutation for the shape. This must be a permutation of {0, 1, 2, ..., N - 1}, where N is the rank of this shape.
- Returns:
  The shape with its axes permuted.
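Dropping validation changes how malformed input behaves. An axis outside the valid range still fails on array access, but a repeated axis silently produces a shape that is not a permutation of the original. `UnsafePermuteAxesSketch` is hypothetical and assumes the typical transpose convention, where result axis i takes the size of original axis axes[i].

```java
import java.util.Arrays;

/** Hypothetical sketch of an unchecked axis permutation. */
final class UnsafePermuteAxesSketch {

    static int[] unsafePermuteAxes(int[] dims, int... axes) {
        int[] permuted = new int[axes.length];
        for (int i = 0; i < axes.length; i++) {
            // No validation: an out-of-range axis throws here, but a repeated axis does not.
            permuted[i] = dims[axes[i]];
        }
        return permuted;
    }

    public static void main(String[] args) {
        int[] dims = {3, 4, 5};
        System.out.println(Arrays.toString(unsafePermuteAxes(dims, 2, 0, 1))); // prints [5, 3, 4]
        System.out.println(Arrays.toString(unsafePermuteAxes(dims, 1, 1, 0))); // prints [4, 4, 3] (not a permutation)
    }
}
```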

totalEntries
public BigInteger totalEntries()
Gets the total number of entries for a tensor with this shape.
- Returns:
  The total number of entries for a tensor with this shape.

totalEntriesIntValueExact
public int totalEntriesIntValueExact()
Gets the total number of entries for a tensor with this shape as a primitive int. If the total number of entries exceeds Integer.MAX_VALUE, an exception is thrown.
This method is likely to be more efficient than totalEntries() if a primitive int value is desired.
- Returns:
  The total number of entries for a tensor with this shape.
- Throws:
  ArithmeticException - If the total number of entries overflows a primitive int.
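The sketch below assumes the total is held as a `BigInteger`. The name totalEntriesIntValueExact mirrors `BigInteger.intValueExact()`. `TotalEntriesSketch` is hypothetical. Its `totalEntries` multiplies exactly with `BigInteger`, while `totalEntriesIntValueExact` stays in primitive ints with `Math.multiplyExact`. Avoiding allocation this way is one plausible reason the int variant could be faster.

```java
import java.math.BigInteger;

/** Hypothetical sketch of exact and int-valued entry counts. */
final class TotalEntriesSketch {

    /** Exact product of all dimensions; BigInteger arithmetic cannot overflow. */
    static BigInteger totalEntries(int[] dims) {
        BigInteger total = BigInteger.ONE;
        for (int d : dims) {
            total = total.multiply(BigInteger.valueOf(d));
        }
        return total;
    }

    /** Same count as an int; throws ArithmeticException if it does not fit. */
    static int totalEntriesIntValueExact(int[] dims) {
        for (int d : dims) {
            if (d == 0) {
                return 0; // the product is 0 even if other dimensions are huge
            }
        }
        int total = 1;
        for (int d : dims) {
            // With all dims >= 1 the partial products never shrink, so an
            // intermediate overflow implies the final product overflows too.
            total = Math.multiplyExact(total, d);
        }
        return total;
    }
}
```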

equals
Checks if an object is equal to this shape.

hashCode
public int hashCode()
Generates the hash code for this shape object. This is computed by passing the dims array of this shape object to Arrays.hashCode(int[]).
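Because the hash code is derived from the dims array alone, a consistent equals compares dims arrays element by element. `HashedShapeSketch` is a hypothetical class that assumes equality depends only on the dimensions, since the equals description does not say what is compared. It pairs `Arrays.equals` with `Arrays.hashCode` so that equal shapes have equal hash codes.

```java
import java.util.Arrays;

/** Hypothetical shape whose equality and hash code depend only on its dims. */
final class HashedShapeSketch {
    private final int[] dims;

    HashedShapeSketch(int... dims) {
        this.dims = dims.clone();
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof HashedShapeSketch)) {
            return false;
        }
        // Element-wise comparison, consistent with the array-based hash code below.
        return Arrays.equals(dims, ((HashedShapeSketch) obj).dims);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(dims); // same approach as documented for Shape.hashCode()
    }
}
```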

toString
Converts this Shape object to a string format.