Class LogisticRegression

java.lang.Object
com.jml.core.Model<double[][],​double[]>
com.jml.classifiers.LogisticRegression

public class LogisticRegression extends Model<double[][],​double[]>
A logistic regression model. Supports binary classification for multiple features.
Fits a logistic function f(x) = 1 / (1 + e^(-w^T x)) to a dataset by minimizing the binary cross-entropy function.
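The logistic function and the binary cross-entropy loss named above can be sketched in a few lines of plain Java. This is only an illustration on raw arrays, not the library's implementation: the class LogisticLossSketch and its helper methods are hypothetical names, and the clamping constant is an assumption added to keep the logarithm finite.

```java
public class LogisticLossSketch {

    // Logistic function applied to the linear score z = w^T x.
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Dot product w^T x for two vectors of equal length.
    static double dot(double[] w, double[] x) {
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * x[i];
        }
        return sum;
    }

    // Mean binary cross-entropy over a dataset with 0/1 targets.
    // Predictions are clamped away from 0 and 1 so that log() stays finite
    // (the clamp value is an assumption of this sketch).
    static double binaryCrossEntropy(double[][] features, double[] targets, double[] w) {
        final double eps = 1e-12;
        double total = 0.0;
        for (int i = 0; i < features.length; i++) {
            double p = sigmoid(dot(w, features[i]));
            p = Math.min(Math.max(p, eps), 1.0 - eps);
            total -= targets[i] * Math.log(p) + (1.0 - targets[i]) * Math.log(1.0 - p);
        }
        return total / features.length;
    }

    public static void main(String[] args) {
        double[][] features = {{1.0, 2.0}, {1.0, -1.0}};
        double[] targets = {1.0, 0.0};
        double[] w = {0.0, 0.0};
        // With all-zero weights every prediction is 0.5, so the loss equals ln(2).
        System.out.println(binaryCrossEntropy(features, targets, w));
    }
}
```

Weights that separate the two samples better than the all-zero vector give a smaller loss, which is the quantity the optimizer drives down during fitting.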
  • Constructor Summary

    Constructors
    Constructor
    Description
    LogisticRegression()
    Creates a logistic regression model.
    LogisticRegression(double learningRate)
    Creates a logistic regression model.
    LogisticRegression(double learningRate, int maxIterations)
    Creates a logistic regression model.
    LogisticRegression(double learningRate, int maxIterations, double threshold)
    Creates a logistic regression model.
  • Method Summary

    Modifier and Type
    Method
    Description
    Model<double[][],​double[]>
    fit​(double[][] features, double[] targets)
    Fits or trains the model with the given features and targets.
    double[]
    getLossHist()
    Gets the loss history from the optimizer.
    linalg.Matrix
    getParams()
    Gets the parameters of the trained model.
    String
    inspect()
    Forms a string of the important aspects of the model. Same as toString().
    double[]
    predict​(double[][] features)
    Uses the fitted/trained model to make predictions on the given features.
    linalg.Matrix
    predict​(linalg.Matrix X, linalg.Matrix w)
    Makes a prediction using a model by specifying the parameters of the model.
    void
    saveModel​(String filePath)
    Saves a trained model to the specified file path.
    String
    toString()
    Forms a string of the important aspects of the model.

    Methods inherited from class com.jml.core.Model

    load

    Methods inherited from class java.lang.Object

    equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
  • Constructor Details

    • LogisticRegression

      public LogisticRegression()
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with default settings: a learning rate of 0.002, 1000 max iterations, and a threshold of 0.5e-5.
    • LogisticRegression

      public LogisticRegression(double learningRate, int maxIterations, double threshold)
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with the specified learning rate, max iterations, and threshold.
      Parameters:
      learningRate - Learning rate to use during optimization.
      maxIterations - Maximum iterations to run optimizer for.
      threshold - Threshold for stopping the optimizer. If the loss becomes less than this value, the optimizer will stop early.
    • LogisticRegression

      public LogisticRegression(double learningRate, int maxIterations)
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with the specified learning rate and max iterations. Defaults to a threshold of 0.5e-5.
      Parameters:
      learningRate - Learning rate to use during optimization.
      maxIterations - Maximum iterations to run optimizer for.
    • LogisticRegression

      public LogisticRegression(double learningRate)
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with the specified learning rate. Defaults to 1000 max iterations and a threshold of 0.5e-5.
      Parameters:
      learningRate - Learning rate to use during optimization.
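How the three constructor parameters interact can be illustrated with a small self-contained sketch. It is not the library's optimizer: for simplicity it uses full-batch gradient descent on plain arrays in place of stochastic gradient descent, and the class GradientDescentSketch, its methods, and the bias-column convention are all assumptions made for the example. It shows the learning rate scaling each update, the iteration cap bounding the loop, and the threshold stopping the loop early once the loss falls below it.

```java
import java.util.Arrays;

public class GradientDescentSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Predicted probability for one sample: sigmoid(w^T row).
    static double predict(double[] row, double[] w) {
        double z = 0.0;
        for (int j = 0; j < w.length; j++) z += w[j] * row[j];
        return sigmoid(z);
    }

    // Mean binary cross-entropy; predictions are clamped so log() stays finite.
    static double loss(double[][] x, double[] y, double[] w) {
        double total = 0.0;
        for (int i = 0; i < x.length; i++) {
            double p = Math.min(Math.max(predict(x[i], w), 1e-12), 1.0 - 1e-12);
            total -= y[i] * Math.log(p) + (1.0 - y[i]) * Math.log(1.0 - p);
        }
        return total / x.length;
    }

    // Runs at most maxIterations updates and stops early once the loss drops below threshold.
    // Returns the loss recorded after each iteration; w is updated in place.
    static double[] fit(double[][] x, double[] y, double[] w,
                        double learningRate, int maxIterations, double threshold) {
        double[] history = new double[maxIterations];
        int used = 0;
        for (int iter = 0; iter < maxIterations; iter++) {
            // Gradient of the mean cross-entropy with respect to w.
            double[] grad = new double[w.length];
            for (int i = 0; i < x.length; i++) {
                double err = predict(x[i], w) - y[i];
                for (int j = 0; j < w.length; j++) grad[j] += err * x[i][j] / x.length;
            }
            for (int j = 0; j < w.length; j++) w[j] -= learningRate * grad[j];
            history[used++] = loss(x, y, w);
            if (history[used - 1] < threshold) break;  // early stop
        }
        return Arrays.copyOf(history, used);
    }

    public static void main(String[] args) {
        // First column is a constant 1 acting as a bias term (an assumption of this sketch).
        double[][] x = {{1, 0.0}, {1, 1.0}, {1, 3.0}, {1, 4.0}};
        double[] y = {0, 0, 1, 1};
        double[] w = new double[2];
        double[] history = fit(x, y, w, 0.1, 1000, 0.5e-5);
        System.out.println("iterations run: " + history.length);
        System.out.println("final loss: " + history[history.length - 1]);
    }
}
```

A larger threshold ends the loop sooner, while a threshold of zero effectively disables early stopping, since the cross-entropy loss stays positive and the loop then runs for the full iteration cap.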
  • Method Details

    • fit

      public Model<double[][],​double[]> fit(double[][] features, double[] targets)
      Fits or trains the model with the given features and targets.
      Specified by:
      fit in class Model<double[][],​double[]>
      Parameters:
      features - The features of the training set.
      targets - The targets of the training set.
      Returns:
      This model, i.e. the trained model.
      Throws:
      IllegalArgumentException - Thrown if the features and targets are not correctly sized per the specification when the model was compiled.
    • predict

      public double[] predict(double[][] features)
      Uses the fitted/trained model to make predictions on the given features.
      Specified by:
      predict in class Model<double[][],​double[]>
      Parameters:
      features - The features to make predictions on.
      Returns:
      The model's predicted labels.
      Throws:
      IllegalArgumentException - Thrown if the features are not correctly sized per the specification when the model was compiled.
      IllegalStateException - Thrown if the model has not been compiled and fit.
    • predict

      public linalg.Matrix predict(linalg.Matrix X, linalg.Matrix w)
      Makes a prediction using a model by specifying the parameters of the model. Unlike the other predict method, no model needs to be trained to use this method since the parameters provided define a model.
      Specified by:
      predict in class Model<double[][],​double[]>
      Parameters:
      w - Parameters of the model
      X - Features to make prediction on
      Returns:
      Prediction on the features using the given model parameters.
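Predicting from explicitly supplied parameters amounts to applying the logistic function to the product of the features and the parameter vector. The sketch below shows that computation on plain arrays rather than linalg.Matrix, whose API is not described here. The class PredictWithParamsSketch, its method names, and the 0.5 cutoff used to turn probabilities into labels are assumptions of the example; whether the library method returns probabilities or labels is not stated in this documentation.

```java
import java.util.Arrays;

public class PredictWithParamsSketch {

    // Computes sigmoid(X w) row by row: one probability per sample.
    static double[] predictProbabilities(double[][] X, double[] w) {
        double[] out = new double[X.length];
        for (int i = 0; i < X.length; i++) {
            if (X[i].length != w.length) {
                throw new IllegalArgumentException(
                        "Row " + i + " has " + X[i].length
                        + " features but w has " + w.length + " entries.");
            }
            double z = 0.0;
            for (int j = 0; j < w.length; j++) {
                z += X[i][j] * w[j];
            }
            out[i] = 1.0 / (1.0 + Math.exp(-z));
        }
        return out;
    }

    // Converts probabilities to 0/1 labels using a 0.5 cutoff (an assumption of this sketch).
    static double[] toLabels(double[] probabilities) {
        double[] labels = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            labels[i] = probabilities[i] >= 0.5 ? 1.0 : 0.0;
        }
        return labels;
    }

    public static void main(String[] args) {
        double[][] X = {{1.0, -2.0}, {1.0, 2.0}};
        double[] w = {0.0, 1.5};  // hand-picked parameters, no training involved
        double[] labels = toLabels(predictProbabilities(X, w));
        System.out.println(Arrays.toString(labels));
    }
}
```

Because the parameters are passed in directly, nothing has to be fitted beforehand, which mirrors the distinction the description draws between this method and the single-argument predict.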
    • getParams

      public linalg.Matrix getParams()
      Gets the parameters of the trained model.
      Specified by:
      getParams in class Model<double[][],​double[]>
      Returns:
      A matrix containing the parameters of the trained model.
    • getLossHist

      public double[] getLossHist()
      Gets the loss history from the optimizer.
      Returns:
      An array containing the loss for each iteration of the optimization algorithm. The index of the array corresponds to the iteration the loss was computed for.
    • saveModel

      public void saveModel(String filePath)
      Saves a trained model to the specified file path.
      Specified by:
      saveModel in class Model<double[][],​double[]>
      Parameters:
      filePath - File path, including extension, to save fitted / trained model to.
    • inspect

      public String inspect()
      Forms a string of the important aspects of the model.
      Same as toString().
      Specified by:
      inspect in class Model<double[][],​double[]>
      Returns:
      Details of model as string.
    • toString

      public String toString()
      Forms a string of the important aspects of the model.
      Specified by:
      toString in class Model<double[][],​double[]>
      Returns:
      String representation of model.