Class Shape

java.lang.Object
org.flag4j.arrays.Shape
All Implemented Interfaces:
Serializable

public class Shape extends Object implements Serializable
Represents the shape of a multidimensional array (e.g. tensor, matrix, vector, etc.), specifying its dimensions and providing utilities for shape-related ops.

A shape is defined by an array of dimensions, where each dimension specifies the size of the tensor along a particular axis. Strides can also be computed for the shape. They specify how many entries of the flat data array to step over to advance by one position along each dimension when traversing an array with the given shape. Strides are always row-major contiguous, which allows efficient array traversal and mapping of nD indices to 1D contiguous indices.
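As a rough illustration of what "row-major contiguous strides" means, the standalone sketch below computes them from a dimensions array. It does not use flag4j. The class `RowMajorStridesSketch` and method `computeStrides` are hypothetical names, and the sketch assumes the total number of entries fits in an `int`.

```java
import java.util.Arrays;

// Hypothetical standalone sketch (not flag4j's implementation):
// computes row-major (C-order) contiguous strides for a dimensions array.
final class RowMajorStridesSketch {

    // strides[i] is the product of dims[i+1], ..., dims[rank-1];
    // the last axis always has stride 1. Assumes the product fits in an int.
    static int[] computeStrides(int... dims) {
        int[] strides = new int[dims.length];
        int stride = 1;
        for (int i = dims.length - 1; i >= 0; i--) {
            strides[i] = stride;
            stride *= dims[i];
        }
        return strides;
    }

    public static void main(String[] args) {
        // For a 3x4x5 shape, one step along axis 0 skips 4*5 = 20 entries,
        // one step along axis 1 skips 5 entries, and one step along axis 2 skips 1.
        System.out.println(Arrays.toString(computeStrides(3, 4, 5))); // [20, 5, 1]
    }
}
```

A rank-0 (scalar) shape yields an empty strides array, since there are no axes to step along.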

This class also supports converting between multidimensional and flat indices, computing the shape's rank (i.e. number of dimensions), computing the total number of entries in an array with the given shape, and manipulating dimensions through swaps or permutations.

The Shape class is immutable with respect to its dimensions, ensuring thread safety and consistency. Strides are computed lazily only when needed to minimize overhead.

This class is a fundamental building block for tensor ops, particularly in contexts where multidimensional indexing and dimension manipulations are required.

Example usage:


 Shape shape = new Shape(); // Creates a shape for a scalar value.
 shape = new Shape(3, 4, 5); // Creates a shape for a 3x4x5 tensor.
 int rank = shape.getRank(); // Gets the rank (number of dimensions).
 int[] strides = shape.getStrides(); // Retrieves the strides for this shape.
 int flatIndex = shape.getFlatIndex(2, 1, 4); // Converts multidimensional indices to a flat index.
 int[] multiDimIndex = shape.getNdIndices(flatIndex); // Converts a flat index back to multidimensional indices.
 
  • Constructor Summary

    Constructors
    Constructor
    Description
    Shape(int... dims)
    Constructs a shape object from specified dimensions.
  • Method Summary

    Modifier and Type
    Method
    Description
    boolean
    equals(Object b)
    Checks if an object is equal to this shape.
    Shape
    flatten()
    Flattens this shape to a rank-1 shape with dimension equal to the product of all of this shape's dimensions.
    int
    get(int i)
    Get the size of the shape object in the specified dimension.
    int[]
    getDims()
    Gets the shape of a tensor as an array of dimensions.
    int
    getFlatIndex(int... indices)
    Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.
    int[]
    getNdIndices(int index)
    Efficiently computes the nD tensor indices based on an index from the internal 1D data array.
    int
    getRank()
    Gets the rank of a tensor with this shape.
    int[]
    getStrides()
    Gets the strides of this shape as an array.
    int
    hashCode()
    Generates the hashcode for this shape object.
    Shape
    permuteAxes(int... axes)
    Permutes the axes of this shape.
    Shape
    slice(int startIdx)
    Returns a slice of this shape starting from the specified index to the end of this shape's dimensions.
    Shape
    slice(int startIdx, int stopIdx)
    Returns a slice of this shape from the specified start index to the stop index of this shape's dimensions.
    Shape
    swapAxes(int axis1, int axis2)
    Swaps two axes of this shape.
    String
    toString()
    Converts this Shape object to a string format.
    BigInteger
    totalEntries()
    Gets the total number of entries for a tensor with this shape.
    int
    totalEntriesIntValueExact()
    Gets the total number of entries for a tensor with this shape.
    int
    unsafeGetFlatIndex(int... indices)
    Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.
    Shape
    unsafePermuteAxes(int... axes)
    Permutes the axes of this shape.

    Methods inherited from class java.lang.Object

    clone, finalize, getClass, notify, notifyAll, wait, wait, wait
  • Constructor Details

    • Shape

      public Shape(int... dims)
      Constructs a shape object from specified dimensions.
      Parameters:
      dims - A list of the dimension measurements for this shape object. All dimensions must be non-negative.
      Throws:
      IllegalArgumentException - If any dimension is negative.
  • Method Details

    • getRank

      public int getRank()
      Gets the rank of a tensor with this shape.
      Returns:
      The rank for a tensor with this shape.
    • getDims

      public int[] getDims()
      Gets the shape of a tensor as an array of dimensions.
      Returns:
      Shape of a tensor as an integer array.
    • getStrides

      public int[] getStrides()
      Gets the strides of this shape as an array. Strides are the step sizes needed to move from one element to another along each axis in the tensor.
      Returns:
      The strides of this shape as an integer array.
    • get

      public int get(int i)
      Get the size of the shape object in the specified dimension.
      Parameters:
      i - Dimension to get the size of.
      Returns:
      The size of this shape object in the specified dimension.
    • slice

      public Shape slice(int startIdx)
      Returns a slice of this shape starting from the specified index to the end of this shape's dimensions.
      Parameters:
      startIdx - The starting index for slicing (inclusive).
      Returns:
      A new Shape object containing the dimensions from startIdx to the end dimension.
      Throws:
      IndexOutOfBoundsException - If startIdx is out of bounds for the rank of this shape.
    • slice

      public Shape slice(int startIdx, int stopIdx)
      Returns a slice of this shape from the specified start index to the stop index of this shape's dimensions.
      Parameters:
      startIdx - The starting index for slicing (inclusive).
      stopIdx - The stopping index for slicing (exclusive).
      Returns:
      A new Shape object containing the dimensions from startIdx to stopIdx.
      Throws:
      IndexOutOfBoundsException - If startIdx or stopIdx is out of bounds.
      IllegalArgumentException - If startIdx > stopIdx.
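The slicing semantics above (start inclusive, stop exclusive) map directly onto `Arrays.copyOfRange`. The hypothetical sketch below applies them to a plain dimensions array. The order of the two validity checks is an assumption, and this sketch treats `startIdx == rank` as valid (yielding an empty, rank-0 result), which the library may handle differently.

```java
import java.util.Arrays;

// Hypothetical sketch of slicing a dimensions array; not flag4j's code.
// The order of the two checks is an assumption.
final class ShapeSliceSketch {

    // Returns dims[startIdx .. stopIdx-1] (start inclusive, stop exclusive).
    static int[] slice(int[] dims, int startIdx, int stopIdx) {
        if (startIdx < 0 || startIdx > dims.length || stopIdx < 0 || stopIdx > dims.length)
            throw new IndexOutOfBoundsException(
                    "Slice [" + startIdx + ", " + stopIdx + ") out of bounds for rank " + dims.length);
        if (startIdx > stopIdx)
            throw new IllegalArgumentException("startIdx > stopIdx: " + startIdx + " > " + stopIdx);
        return Arrays.copyOfRange(dims, startIdx, stopIdx);
    }

    // Single-argument form: slices from startIdx through the last dimension.
    static int[] slice(int[] dims, int startIdx) {
        return slice(dims, startIdx, dims.length);
    }

    public static void main(String[] args) {
        int[] dims = {3, 4, 5};
        System.out.println(Arrays.toString(slice(dims, 1)));    // [4, 5]
        System.out.println(Arrays.toString(slice(dims, 0, 2))); // [3, 4]
    }
}
```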
    • flatten

      public Shape flatten()
      Flattens this shape to a rank-1 shape with dimension equal to the product of all of this shape's dimensions.
      Returns:
      A rank-1 shape with dimension equal to the product of all of this shape's dimensions.
      Throws:
      ArithmeticException - If the product of this shape's dimensions is too large to be stored in a 32-bit integer.
    • getFlatIndex

      public int getFlatIndex(int... indices)
      Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.
      Parameters:
      indices - Indices of tensor with this shape.
      Returns:
      The index of the element at the specified indices in the 1D data array of a dense tensor.
      Throws:
      IllegalArgumentException - If the number of indices does not match the rank of this shape.
      IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
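For a row-major contiguous layout, the flat index is the sum of each index multiplied by that axis's stride. The hypothetical sketch below computes it with the rank and bounds checks described above, building each stride on the fly from the last axis backwards. It is an illustration under those assumptions, not flag4j's implementation.

```java
// Hypothetical sketch of a bounds-checked flat-index computation for a
// row-major contiguous layout; not flag4j's implementation.
final class FlatIndexSketch {

    static int flatIndex(int[] dims, int... indices) {
        if (indices.length != dims.length)
            throw new IllegalArgumentException(
                    "Expected " + dims.length + " indices but got " + indices.length);

        int flat = 0;
        int stride = 1; // row-major stride of axis i, built from the last axis backwards
        for (int i = dims.length - 1; i >= 0; i--) {
            if (indices[i] < 0 || indices[i] >= dims[i])
                throw new IndexOutOfBoundsException(
                        "Index " + indices[i] + " out of bounds for axis " + i + " with size " + dims[i]);
            flat += indices[i] * stride;
            stride *= dims[i];
        }
        return flat;
    }

    public static void main(String[] args) {
        // 2*20 + 1*5 + 4*1 = 49 for a 3x4x5 shape (strides 20, 5, 1).
        System.out.println(flatIndex(new int[]{3, 4, 5}, 2, 1, 4)); // 49
    }
}
```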
    • unsafeGetFlatIndex

      public int unsafeGetFlatIndex(int... indices)

      Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.

      Warning: Unlike getFlatIndex(int...), this method does not perform bounds checking on indices. If the indices are not valid for this shape, it may throw an exception or silently return an incorrect index.

      Parameters:
      indices - Indices of tensor with this shape.
      Returns:
      The index of the element at the specified indices in the 1D data array of a dense tensor.
      Throws:
      IllegalArgumentException - If the number of indices does not match the rank of this shape.
      IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
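To see why the missing bounds check matters, the hypothetical sketch below computes a flat index without any validation, using Horner's scheme (which gives the same result as summing indices times row-major strides). For a 3x4x5 shape, the out-of-range indices (0, 5, 0) silently map to the same flat index as the valid indices (1, 1, 0). This illustrates the general hazard only; it is not flag4j's code.

```java
// Hypothetical unchecked flat-index computation (not flag4j's implementation).
// No rank or bounds checks: invalid input either throws incidentally
// (e.g. too few indices) or silently produces a wrong index.
final class UncheckedFlatIndexSketch {

    static int uncheckedFlatIndex(int[] dims, int... indices) {
        int flat = 0;
        // Horner's scheme: ((i0 * d1 + i1) * d2 + i2) equals i0*d1*d2 + i1*d2 + i2.
        for (int i = 0; i < dims.length; i++) {
            flat = flat * dims[i] + indices[i];
        }
        return flat;
    }

    public static void main(String[] args) {
        int[] dims = {3, 4, 5};
        System.out.println(uncheckedFlatIndex(dims, 1, 1, 0)); // 25 (valid indices)
        System.out.println(uncheckedFlatIndex(dims, 0, 5, 0)); // 25 (index 5 on axis 1 is out of range, yet no error)
    }
}
```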
    • getNdIndices

      public int[] getNdIndices(int index)
      Efficiently computes the nD tensor indices based on an index from the internal 1D data array.
      Parameters:
      index - Index of internal 1D data array.
      Returns:
      The multidimensional indices corresponding to the 1D data array index. This will be an array of integers with length equal to the rank of this shape.
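One common way to invert a row-major flat index is repeated remainder and division, starting from the last (fastest-varying) axis. The hypothetical sketch below does this for a plain dimensions array. The range check is an addition for illustration (the method above does not document an exception), and it is not flag4j's implementation.

```java
import java.util.Arrays;

// Hypothetical sketch of converting a row-major flat index back to nD indices
// by repeated remainder and division; not flag4j's implementation.
final class NdIndicesSketch {

    static int[] ndIndices(int[] dims, int index) {
        // Running product clamped to 2^31 so it can never overflow a long;
        // a zero dimension still propagates to a total of 0.
        long total = 1;
        for (int d : dims) total = Math.min(total * d, (long) Integer.MAX_VALUE + 1);

        // Added for illustration: reject indices outside [0, total).
        if (index < 0 || index >= total)
            throw new IndexOutOfBoundsException("Index " + index + " out of range for " + total + " entries");

        int[] nd = new int[dims.length];
        for (int i = dims.length - 1; i >= 0; i--) { // last axis varies fastest
            nd[i] = index % dims[i];
            index /= dims[i];
        }
        return nd;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(ndIndices(new int[]{3, 4, 5}, 49))); // [2, 1, 4]
    }
}
```

Because the range check runs first, a shape containing a zero dimension rejects every index before any division by zero can occur.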
    • swapAxes

      public Shape swapAxes(int axis1, int axis2)
      Swaps two axes of this shape. If this shape has had its strides computed, then new strides will also be computed for the resulting shape.
      Parameters:
      axis1 - First axis to swap.
      axis2 - Second axis to swap.
      Returns:
      A copy of this shape with the specified axes swapped.
      Throws:
      ArrayIndexOutOfBoundsException - If either axis is not within [0, rank-1].
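Because the shape is immutable, swapping axes amounts to copying the dimensions and exchanging two entries. The hypothetical sketch below shows only that dimension swap. It does not model the documented stride recomputation, and it is not flag4j's code.

```java
import java.util.Arrays;

// Hypothetical sketch of swapping two axes of a dimensions array; not flag4j's code.
final class SwapAxesSketch {

    static int[] swapAxes(int[] dims, int axis1, int axis2) {
        int[] swapped = dims.clone(); // leave the original dims untouched
        // Array access throws ArrayIndexOutOfBoundsException if an axis is outside [0, rank-1].
        int tmp = swapped[axis1];
        swapped[axis1] = swapped[axis2];
        swapped[axis2] = tmp;
        return swapped;
    }

    public static void main(String[] args) {
        int[] dims = {3, 4, 5};
        System.out.println(Arrays.toString(swapAxes(dims, 0, 2))); // [5, 4, 3]
        System.out.println(Arrays.toString(dims));                 // [3, 4, 5]
    }
}
```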
    • permuteAxes

      public Shape permuteAxes(int... axes)
      Permutes the axes of this shape.
      Parameters:
      axes - New axes permutation for the shape. This must be a permutation of {0, 1, 2, ..., N-1} where N is the rank of this shape.
      Returns:
      A new shape with its axes permuted according to axes.
      Throws:
      ArrayIndexOutOfBoundsException - If axes is not a permutation of {0, 1, 2, ..., N-1}.
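The sketch below validates that `axes` is a permutation of {0, 1, ..., N-1} and then reorders the dimensions. It assumes the NumPy-style convention `newDims[i] = dims[axes[i]]`. That convention, the helper names, and the use of `ArrayIndexOutOfBoundsException` for every invalid case (chosen to mirror the documented exception) are all assumptions, not flag4j's confirmed behavior.

```java
import java.util.Arrays;

// Hypothetical sketch of a validated axis permutation; not flag4j's code.
// Assumes the NumPy-style convention newDims[i] = dims[axes[i]].
final class PermuteAxesSketch {

    static int[] permuteAxes(int[] dims, int... axes) {
        if (axes.length != dims.length)
            throw new ArrayIndexOutOfBoundsException(
                    "Expected " + dims.length + " axes but got " + axes.length);

        boolean[] seen = new boolean[dims.length];
        int[] permuted = new int[dims.length];
        for (int i = 0; i < axes.length; i++) {
            int axis = axes[i];
            // Reject out-of-range or repeated axes so only true permutations pass.
            if (axis < 0 || axis >= dims.length || seen[axis])
                throw new ArrayIndexOutOfBoundsException(
                        "axes is not a permutation of {0, ..., " + (dims.length - 1) + "}");
            seen[axis] = true;
            permuted[i] = dims[axis];
        }
        return permuted;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(permuteAxes(new int[]{3, 4, 5}, 2, 0, 1))); // [5, 3, 4]
    }
}
```

The `seen` array makes the permutation check a single linear pass, which is the main cost an unchecked variant such as `unsafePermuteAxes` would skip.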
    • unsafePermuteAxes

      public Shape unsafePermuteAxes(int... axes)

      Permutes the axes of this shape.

      Warning: Unlike permuteAxes(int...), this method does not perform bounds checking on axes or ensure that axes is a permutation of {0, 1, 2, ..., N-1}. This may result in unexpected behavior if axes is malformed.

      Parameters:
      axes - New axes permutation for the shape. This must be a permutation of {0, 1, 2, ..., N-1} where N is the rank of this shape.
      Returns:
      A new shape with its axes permuted according to axes.
    • totalEntries

      public BigInteger totalEntries()
      Gets the total number of entries for a tensor with this shape.
      Returns:
      The total number of entries for a tensor with this shape.
    • totalEntriesIntValueExact

      public int totalEntriesIntValueExact()

      Gets the total number of entries for a tensor with this shape. If the total number of entries exceeds Integer.MAX_VALUE, an exception is thrown.

      This method is likely to be more efficient than totalEntries() if a primitive int value is desired.

      Returns:
      The total number of entries for a tensor with this shape.
      Throws:
      ArithmeticException - If the total number of entries overflows a primitive int.
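The trade-off between the two counting methods can be sketched with standard JDK calls: `BigInteger` never overflows but allocates on every multiplication, while primitive arithmetic with `Math.multiplyExact` stays cheap and throws `ArithmeticException` on overflow. The helper names below are hypothetical and this is not flag4j's implementation. The early zero check is a design choice of this sketch, so that a shape such as 65536x65536x0 returns 0 instead of overflowing midway.

```java
import java.math.BigInteger;

// Hypothetical sketch contrasting an arbitrary-precision entry count with an
// exact int count; not flag4j's implementation.
final class TotalEntriesSketch {

    // Never overflows: the product is accumulated as a BigInteger.
    static BigInteger totalEntries(int... dims) {
        BigInteger total = BigInteger.ONE;
        for (int d : dims) total = total.multiply(BigInteger.valueOf(d));
        return total;
    }

    // Primitive arithmetic avoids BigInteger allocations.
    static int totalEntriesIntValueExact(int... dims) {
        for (int d : dims) if (d == 0) return 0; // any zero dimension means zero entries
        int total = 1;
        // Math.multiplyExact throws ArithmeticException if the product overflows an int.
        for (int d : dims) total = Math.multiplyExact(total, d);
        return total;
    }

    public static void main(String[] args) {
        System.out.println(totalEntries(65536, 65536));         // 4294967296
        System.out.println(totalEntriesIntValueExact(3, 4, 5)); // 60
    }
}
```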
    • equals

      public boolean equals(Object b)
      Checks if an object is equal to this shape.
      Overrides:
      equals in class Object
      Parameters:
      b - Object to compare with this shape.
      Returns:
      True if b is a Shape object and equal to this shape.
    • hashCode

      public int hashCode()
      Generates the hashcode for this shape object. This is computed by passing the dims array of this shape object to Arrays.hashCode(int[]).
      Overrides:
      hashCode in class Object
      Returns:
      The hashcode for this shape object.
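The `hashCode` description states that the hash is `Arrays.hashCode(int[])` over the dims array. The hypothetical value class below pairs that with `Arrays.equals` in `equals`, so equal dimensions give equal hashes and the object works as a `HashSet` or `HashMap` key. That `equals` compares dims this way is an assumption, and this is not flag4j's `Shape`.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical value class illustrating dims-based equality; not flag4j's Shape.
final class DimsKeySketch {
    private final int[] dims;

    DimsKeySketch(int... dims) {
        this.dims = dims.clone(); // defensive copy keeps the object immutable
    }

    @Override
    public boolean equals(Object b) {
        if (this == b) return true;
        if (!(b instanceof DimsKeySketch)) return false;
        return Arrays.equals(dims, ((DimsKeySketch) b).dims);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(dims); // consistent with equals: equal dims give equal hashes
    }

    public static void main(String[] args) {
        Set<DimsKeySketch> seen = new HashSet<>();
        seen.add(new DimsKeySketch(3, 4, 5));
        System.out.println(seen.contains(new DimsKeySketch(3, 4, 5))); // true
        System.out.println(seen.contains(new DimsKeySketch(5, 4, 3))); // false
    }
}
```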
    • toString

      public String toString()
      Converts this Shape object to a string format.
      Overrides:
      toString in class Object
      Returns:
      The string representation for this Shape object.