Class CooTensorDot<T extends Semiring<T>>

java.lang.Object
org.flag4j.linalg.ops.TensorDot<T[]>
org.flag4j.linalg.ops.sparse.coo.CooTensorDot<T>
Type Parameters:
T - The type of semiring that the elements of the tensors in the dot product belong to.

public class CooTensorDot<T extends Semiring<T>> extends TensorDot<T[]>
Instances of this class can be used to compute the tensor dot product between two sparse COO tensors.
  • Constructor Details

    • CooTensorDot

      public CooTensorDot(Shape shape1, T[] src1, int[][] indices1, Shape shape2, T[] src2, int[][] indices2, int[] src1Axes, int[] src2Axes)
      Constructs a tensor dot product problem for the contraction of two tensors over the specified sets of axes, that is, the sum of products of the two tensors' entries along the axes paired in src1Axes and src2Axes. The contraction itself is carried out by compute(T[]).
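As a formula, the contraction can be written as below. This is a sketch under one assumption the page does not state: the result's axes are taken to be the free (non-contracted) axes of the first tensor followed by those of the second, as in NumPy's tensordot. Here A and B are the two tensors, (p_k, q_k) for k = 1, ..., K are the paired entries of src1Axes and src2Axes, d_k is the common dimension shape1[p_k] = shape2[q_k], and + and × are read as the addition and multiplication of the semiring T.

```latex
C_{\mathbf{i},\mathbf{j}}
  = \sum_{s_1=0}^{d_1-1} \cdots \sum_{s_K=0}^{d_K-1}
    A_{\mathbf{i}\,[p_1 \mapsto s_1,\,\ldots,\,p_K \mapsto s_K]}
    \times
    B_{\mathbf{j}\,[q_1 \mapsto s_1,\,\ldots,\,q_K \mapsto s_K]}
```

Here \mathbf{i} indexes the free axes of A, \mathbf{j} indexes the free axes of B, and [p_k \mapsto s_k] means the index s_k is placed on axis p_k. For two matrices with src1Axes = {1} and src2Axes = {0}, this reduces to the ordinary matrix product C_{ij} = \sum_s A_{is} B_{sj}.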
      Parameters:
      shape1 - Shape of the first tensor in the contraction.
      src1 - Non-zero data of the first tensor in the contraction.
      indices1 - Indices of the non-zero entries of the first tensor in the contraction.
      shape2 - Shape of the second tensor in the contraction.
      src2 - Non-zero data of the second tensor in the contraction.
      indices2 - Indices of the non-zero entries of the second tensor in the contraction.
      src1Axes - Axes of the first tensor to contract over. src1Axes[i] is paired with src2Axes[i].
      src2Axes - Axes of the second tensor to contract over. src2Axes[i] is paired with src1Axes[i].
      Throws:
      IllegalArgumentException - If src1Axes and src2Axes do not match in length, if any of the axes are out of bounds for the corresponding tensor, or if the shapes of the two tensors do not match pairwise along the axes specified in src1Axes and src2Axes.
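The argument checks listed under Throws can be sketched as follows. TensorDotChecks, outputShape and outputSize are hypothetical names, not part of flag4j, and the output shape assumes the NumPy tensordot ordering (free axes of the first tensor, then free axes of the second). The sketch additionally rejects repeated axes, which the page does not mention.

```java
/**
 * Hypothetical helper (not flag4j code) mirroring the constructor's documented checks
 * and deriving the shape and size of the dense result.
 */
final class TensorDotChecks {

    private TensorDotChecks() {}

    /** Validates the axes and returns the output shape (assumed NumPy tensordot ordering). */
    static int[] outputShape(int[] shape1, int[] shape2, int[] src1Axes, int[] src2Axes) {
        if (src1Axes.length != src2Axes.length) {
            throw new IllegalArgumentException("src1Axes and src2Axes must have the same length.");
        }
        boolean[] contracted1 = new boolean[shape1.length];
        boolean[] contracted2 = new boolean[shape2.length];

        for (int k = 0; k < src1Axes.length; k++) {
            int a = src1Axes[k];
            int b = src2Axes[k];
            // Each axis must be a valid axis of its own tensor.
            if (a < 0 || a >= shape1.length || b < 0 || b >= shape2.length) {
                throw new IllegalArgumentException("Axis pair (" + a + ", " + b + ") is out of bounds.");
            }
            // Extra check in this sketch: an axis may be contracted only once.
            if (contracted1[a] || contracted2[b]) {
                throw new IllegalArgumentException("Axis pair (" + a + ", " + b + ") repeats an axis.");
            }
            // Paired axes must have the same dimension.
            if (shape1[a] != shape2[b]) {
                throw new IllegalArgumentException("Dimension mismatch: shape1[" + a + "]=" + shape1[a]
                        + " but shape2[" + b + "]=" + shape2[b] + ".");
            }
            contracted1[a] = true;
            contracted2[b] = true;
        }

        // Keep the free axes: those of the first tensor, then those of the second.
        int[] out = new int[shape1.length + shape2.length - 2 * src1Axes.length];
        int pos = 0;
        for (int i = 0; i < shape1.length; i++) if (!contracted1[i]) out[pos++] = shape1[i];
        for (int j = 0; j < shape2.length; j++) if (!contracted2[j]) out[pos++] = shape2[j];
        return out;
    }

    /** Number of entries in a dense tensor of the given shape (1 for a scalar result). */
    static int outputSize(int[] outShape) {
        int size = 1;
        for (int d : outShape) size *= d;
        return size;
    }
}
```

For example, shapes {2, 3, 4} and {4, 3, 5} contracted with src1Axes = {1, 2} and src2Axes = {1, 0} pass every check and give an output shape of {2, 5}, i.e. 10 entries.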
  • Method Details

    • compute

      public void compute(T[] dest)
      Computes this tensor dot product as specified in the constructor.
      Specified by:
      compute in class TensorDot<T[]>
      Parameters:
      dest - The array in which to store the data of the dense tensor resulting from this tensor dot product. Its length should equal the value returned by TensorDot.getOutputSize().
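A simplified version of what compute does can be sketched over double values instead of a generic semiring. CooDotSketch is a hypothetical name and this is not flag4j's implementation. It assumes each row indicesN[p] holds the multi-index of the non-zero srcN[p], and it assumes the NumPy tensordot output ordering with row-major layout in dest. It compares every pair of non-zeros, so it costs O(nnz1 · nnz2); the page does not say which strategy the library uses.

```java
import java.util.Arrays;

/** Hypothetical naive COO tensor dot product over doubles (not flag4j code). */
final class CooDotSketch {

    private CooDotSketch() {}

    static void compute(int[] shape1, double[] src1, int[][] indices1,
                        int[] shape2, double[] src2, int[][] indices2,
                        int[] src1Axes, int[] src2Axes, double[] dest) {
        int k = src1Axes.length;

        // Mark contracted axes so the free axes can be enumerated.
        boolean[] contracted1 = new boolean[shape1.length];
        boolean[] contracted2 = new boolean[shape2.length];
        for (int a : src1Axes) contracted1[a] = true;
        for (int b : src2Axes) contracted2[b] = true;

        // Output dims: free axes of tensor 1, then free axes of tensor 2 (assumed ordering).
        int[] freeDims = new int[shape1.length + shape2.length - 2 * k];
        int pos = 0;
        for (int i = 0; i < shape1.length; i++) if (!contracted1[i]) freeDims[pos++] = shape1[i];
        for (int j = 0; j < shape2.length; j++) if (!contracted2[j]) freeDims[pos++] = shape2[j];

        // Row-major strides of the dense output.
        int[] strides = new int[freeDims.length];
        int stride = 1;
        for (int d = freeDims.length - 1; d >= 0; d--) {
            strides[d] = stride;
            stride *= freeDims[d];
        }

        // This sketch zeroes dest before accumulating into it.
        Arrays.fill(dest, 0.0);

        for (int p = 0; p < src1.length; p++) {
            int[] idx1 = indices1[p];
            for (int q = 0; q < src2.length; q++) {
                int[] idx2 = indices2[q];

                // A pair contributes only if it agrees on every contracted axis pair.
                boolean match = true;
                for (int t = 0; t < k && match; t++) {
                    match = idx1[src1Axes[t]] == idx2[src2Axes[t]];
                }
                if (!match) continue;

                // Flatten the free indices of both non-zeros into an offset in dest.
                int offset = 0;
                int slot = 0;
                for (int i = 0; i < idx1.length; i++) if (!contracted1[i]) offset += idx1[i] * strides[slot++];
                for (int j = 0; j < idx2.length; j++) if (!contracted2[j]) offset += idx2[j] * strides[slot++];

                dest[offset] += src1[p] * src2[q];
            }
        }
    }
}
```

With A = [[1, 0], [0, 2]] (values {1, 2}, indices {{0, 0}, {1, 1}}) and B = [[0, 3], [4, 0]] (values {3, 4}, indices {{0, 1}, {1, 0}}), contracting src1Axes = {1} with src2Axes = {0} is the matrix product and fills a length-4 dest with {0, 3, 8, 0}. A more efficient variant could first group the non-zeros of one tensor by their contracted indices so that only matching pairs are visited.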