Class TensorDot<T>

java.lang.Object
org.flag4j.linalg.ops.TensorDot<T>
Type Parameters:
T - Type of the storage for the data of the tensor.
Direct Known Subclasses:
CooTensorDot, DenseSemiringTensorDot, RealDenseTensorDot

public abstract class TensorDot<T> extends Object
The base class for all classes whose instances may be used to compute a tensor dot product.
  • Field Details

    • shape1

      protected Shape shape1
    • shape2

      protected Shape shape2
    • src1

      protected T src1
    • src2

      protected T src2
    • src1Axes

      protected int[] src1Axes
    • src2Axes

      protected int[] src2Axes
    • newShape1

      protected Shape newShape1
    • newShape2

      protected Shape newShape2
    • destShape

      protected Shape destShape
    • destLength

      protected int destLength
    • src1NewAxes

      protected int[] src1NewAxes
    • src2NewAxes

      protected int[] src2NewAxes
    • src1Dims

      protected int[] src1Dims
    • src2Dims

      protected int[] src2Dims
  • Constructor Details

    • TensorDot

      protected TensorDot(Shape shape1, T src1, Shape shape2, T src2, int[] src1Axes, int[] src2Axes)
      Constructs a tensor dot product problem for contracting two tensors over the specified axes. That is, computes the sum of products of the entries of the two tensors along each pair of axes (src1Axes[i], src2Axes[i]).
      Parameters:
      shape1 - Shape of the first tensor in the contraction.
      src1 - Entries (dense) or non-zero data (sparse) of the first tensor in the contraction.
      shape2 - Shape of the second tensor in the contraction.
      src2 - Entries (dense) or non-zero data (sparse) of the second tensor in the contraction.
      src1Axes - Axes of the src1 tensor to contract over. Each src1Axes[i] is paired with src2Axes[i].
      src2Axes - Axes of the src2 tensor to contract over. Each src2Axes[i] is paired with src1Axes[i].
      Throws:
      IllegalArgumentException - If src1Axes and src2Axes differ in length, if any axis is out of bounds for its tensor, or if the shapes of the two tensors do not match along each pair of axes in src1Axes and src2Axes.
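To make the contraction concrete, here is a minimal, self-contained sketch of a dense tensor dot product on row-major double[] arrays. It is not flag4j's implementation: the class TensorDotSketch and all of its methods are hypothetical, shapes are plain int[] rather than Shape, and the output axis order (free axes of the first tensor followed by free axes of the second) is assumed to follow NumPy's tensordot convention. The checks in validate mirror the exceptions documented above, plus an extra repeated-axis check added for safety.

```java
import java.util.Arrays;

/** Hypothetical sketch: dense tensor contraction over paired axes (NumPy tensordot ordering assumed). */
class TensorDotSketch {

    /** Row-major strides: the last axis varies fastest. */
    static int[] strides(int[] shape) {
        int[] s = new int[shape.length];
        int acc = 1;
        for (int i = shape.length - 1; i >= 0; i--) {
            s[i] = acc;
            acc *= shape[i];
        }
        return s;
    }

    static int product(int[] dims) {
        int p = 1;
        for (int d : dims) p *= d;
        return p;
    }

    /** Advances a row-major multi-index by one position, odometer style. */
    static void increment(int[] idx, int[] dims) {
        for (int i = idx.length - 1; i >= 0; i--) {
            if (++idx[i] < dims[i]) return;
            idx[i] = 0;
        }
    }

    /** Axes not being contracted, in ascending order. */
    static int[] freeAxes(int rank, int[] axes) {
        boolean[] used = new boolean[rank];
        for (int a : axes) used[a] = true;
        int[] free = new int[rank - axes.length];
        for (int i = 0, k = 0; i < rank; i++) if (!used[i]) free[k++] = i;
        return free;
    }

    /** Mirrors the documented IllegalArgumentException conditions (plus a repeated-axis check). */
    static void validate(int[] shape1, int[] shape2, int[] axes1, int[] axes2) {
        if (axes1.length != axes2.length) throw new IllegalArgumentException("Axis arrays differ in length.");
        boolean[] seen1 = new boolean[shape1.length], seen2 = new boolean[shape2.length];
        for (int i = 0; i < axes1.length; i++) {
            int a1 = axes1[i], a2 = axes2[i];
            if (a1 < 0 || a1 >= shape1.length || a2 < 0 || a2 >= shape2.length)
                throw new IllegalArgumentException("Axis out of bounds.");
            if (seen1[a1] || seen2[a2]) throw new IllegalArgumentException("Repeated axis.");
            seen1[a1] = seen2[a2] = true;
            if (shape1[a1] != shape2[a2]) throw new IllegalArgumentException("Shapes differ along paired axes.");
        }
    }

    static double[] tensorDot(int[] shape1, double[] src1, int[] shape2, double[] src2, int[] axes1, int[] axes2) {
        validate(shape1, shape2, axes1, axes2);
        int[] free1 = freeAxes(shape1.length, axes1), free2 = freeAxes(shape2.length, axes2);
        int[] st1 = strides(shape1), st2 = strides(shape2);

        // Output dims: free axes of src1, then free axes of src2. Summed dims: the paired axes.
        int[] outDims = new int[free1.length + free2.length];
        for (int i = 0; i < free1.length; i++) outDims[i] = shape1[free1[i]];
        for (int i = 0; i < free2.length; i++) outDims[free1.length + i] = shape2[free2[i]];
        int[] sumDims = new int[axes1.length];
        for (int i = 0; i < axes1.length; i++) sumDims[i] = shape1[axes1[i]];
        int sumCount = product(sumDims);

        double[] dest = new double[product(outDims)];
        int[] outIdx = new int[outDims.length];
        for (int o = 0; o < dest.length; o++, increment(outIdx, outDims)) {
            int base1 = 0, base2 = 0; // offsets fixed by the free indices
            for (int i = 0; i < free1.length; i++) base1 += outIdx[i] * st1[free1[i]];
            for (int i = 0; i < free2.length; i++) base2 += outIdx[free1.length + i] * st2[free2[i]];
            int[] sumIdx = new int[sumDims.length];
            double acc = 0.0;
            for (int s = 0; s < sumCount; s++, increment(sumIdx, sumDims)) {
                int off1 = base1, off2 = base2;
                for (int i = 0; i < sumIdx.length; i++) {
                    off1 += sumIdx[i] * st1[axes1[i]];
                    off2 += sumIdx[i] * st2[axes2[i]];
                }
                acc += src1[off1] * src2[off2];
            }
            dest[o] = acc;
        }
        return dest;
    }

    public static void main(String[] args) {
        // A (2x3) contracted with B (3x2) over A's axis 1 and B's axis 0 is the matrix product AB.
        double[] c = tensorDot(new int[]{2, 3}, new double[]{1, 2, 3, 4, 5, 6},
                               new int[]{3, 2}, new double[]{7, 8, 9, 10, 11, 12},
                               new int[]{1}, new int[]{0});
        System.out.println(Arrays.toString(c)); // [58.0, 64.0, 139.0, 154.0]
    }
}
```

Passing empty axis arrays yields an outer product, and contracting every axis yields a single-entry (scalar) result, which the sketch handles because an empty dimension list has product 1.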
  • Method Details

    • getOutputShape

      public Shape getOutputShape()
      Gets the shape of the tensor resulting from this tensor dot product as specified in the constructor.
      Returns:
      The shape of the tensor resulting from this tensor dot product as specified in the constructor.
    • getOutputSize

      public int getOutputSize()
      Gets the total number of entries in the tensor resulting from this tensor dot product as specified in the constructor.
      Returns:
      The total number of entries in the tensor resulting from this tensor dot product as specified in the constructor.
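getOutputShape() and getOutputSize() describe the destination before compute runs. The sketch below shows one way such values can be derived from the two shapes and the paired axes, assuming NumPy's tensordot convention (free axes of the first tensor, then free axes of the second) and axes that have already been validated. OutputShapeSketch, outputShape, and outputSize are hypothetical names, not part of flag4j.

```java
import java.util.Arrays;

/** Hypothetical sketch: output shape and size of a tensor dot product (NumPy tensordot ordering assumed). */
class OutputShapeSketch {

    /** Free axes of shape1 (in order) followed by free axes of shape2 (in order). Axes assumed valid. */
    static int[] outputShape(int[] shape1, int[] shape2, int[] axes1, int[] axes2) {
        int[] out = new int[shape1.length + shape2.length - axes1.length - axes2.length];
        int k = appendFree(shape1, axes1, out, 0);
        appendFree(shape2, axes2, out, k);
        return out;
    }

    /** Copies the dimensions of non-contracted axes into out starting at k; returns the next free slot. */
    private static int appendFree(int[] shape, int[] axes, int[] out, int k) {
        outer:
        for (int i = 0; i < shape.length; i++) {
            for (int a : axes) if (a == i) continue outer;
            out[k++] = shape[i];
        }
        return k;
    }

    /** Total number of entries: product of the output dimensions (1 for a scalar result). */
    static int outputSize(int[] outShape) {
        int size = 1;
        for (int d : outShape) size = Math.multiplyExact(size, d); // throws ArithmeticException on overflow
        return size;
    }

    public static void main(String[] args) {
        // Contract a (2, 3, 4) tensor with a (4, 3, 5) tensor over axes {1, 2} and {1, 0}.
        int[] shape = outputShape(new int[]{2, 3, 4}, new int[]{4, 3, 5},
                                  new int[]{1, 2}, new int[]{1, 0});
        System.out.println(Arrays.toString(shape) + " -> " + outputSize(shape)); // [2, 5] -> 10
    }
}
```

Using Math.multiplyExact is a design choice in this sketch: an int-sized entry count that overflows fails loudly instead of silently wrapping.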
    • compute

      public abstract void compute(T dest)
      Computes this tensor dot product as specified in the constructor.
      Parameters:
      dest - Array in which to store the entries of the dense tensor resulting from this tensor dot product. Its length should be obtained from getOutputSize().
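The documented contract of compute(T) is that the caller sizes dest using getOutputSize() and compute then fills it. The sketch below illustrates that caller-allocates pattern with a simplified, hypothetical base class DotProblem<T> (not flag4j's TensorDot: here getOutputSize is abstract and shapes are int[]) and a hypothetical subclass MatrixDot that handles only the matrix-product case. Whether flag4j's compute clears dest first is not stated above; this sketch clears it as an assumption.

```java
import java.util.Arrays;

/** Hypothetical sketch of the caller-allocates contract: size dest with getOutputSize(), then compute(dest). */
class DotUsageSketch {

    /** Simplified stand-in for a TensorDot-style base class (not flag4j's TensorDot). */
    abstract static class DotProblem<T> {
        protected final int[] shape1, shape2;
        protected final T src1, src2;

        protected DotProblem(int[] shape1, T src1, int[] shape2, T src2) {
            this.shape1 = shape1;
            this.src1 = src1;
            this.shape2 = shape2;
            this.src2 = src2;
        }

        /** Number of entries the caller must allocate for dest. */
        public abstract int getOutputSize();

        /** Writes the result into dest, which must hold getOutputSize() entries. */
        public abstract void compute(T dest);
    }

    /** Hypothetical subclass: contracts axis 1 of an (m, k) matrix with axis 0 of a (k, n) matrix. */
    static final class MatrixDot extends DotProblem<double[]> {
        MatrixDot(int[] shape1, double[] src1, int[] shape2, double[] src2) {
            super(shape1, src1, shape2, src2);
            if (shape1[1] != shape2[0]) throw new IllegalArgumentException("Inner dimensions differ.");
        }

        @Override
        public int getOutputSize() {
            return shape1[0] * shape2[1];
        }

        @Override
        public void compute(double[] dest) {
            int m = shape1[0], k = shape1[1], n = shape2[1];
            if (dest.length != m * n) throw new IllegalArgumentException("dest has the wrong length.");
            Arrays.fill(dest, 0.0); // clear first so a reused dest does not accumulate stale values
            for (int i = 0; i < m; i++)
                for (int p = 0; p < k; p++) {
                    double a = src1[i * k + p];
                    for (int j = 0; j < n; j++) dest[i * n + j] += a * src2[p * n + j];
                }
        }
    }

    public static void main(String[] args) {
        MatrixDot op = new MatrixDot(new int[]{2, 3}, new double[]{1, 2, 3, 4, 5, 6},
                                     new int[]{3, 2}, new double[]{7, 8, 9, 10, 11, 12});
        double[] dest = new double[op.getOutputSize()]; // size the buffer from the problem
        op.compute(dest);
        System.out.println(Arrays.toString(dest)); // [58.0, 64.0, 139.0, 154.0]
    }
}
```

Letting the caller own dest allows one buffer to be reused across repeated contractions of the same shapes, at the cost of requiring the caller to query the output size first.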