Class LogisticRegression


public class LogisticRegression extends Model<double[][], double[]>
A logistic regression model. Supports binary classification for multiple features.
Fits a logistic function $f(x) = \frac{1}{1 + e^{-w^\top x}}$ to a dataset by minimizing the binary cross-entropy loss.
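The following is a minimal sketch of the logistic function and the binary cross-entropy described above, written with plain `double[]` arrays instead of the library's `linalg.Matrix`. It computes the mean loss $L = -\frac{1}{n}\sum_i [y_i \ln p_i + (1 - y_i)\ln(1 - p_i)]$. The class and method names are hypothetical and not part of the JML API, and the clipping epsilon of 1e-12 is an assumption added to avoid `ln(0)`.

```java
// Hypothetical helper class (not part of the JML API) illustrating the
// logistic function and the binary cross-entropy loss with plain arrays.
class LogisticMath {

    // Logistic function f(z) = 1 / (1 + e^{-z}), where z = w^T x.
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Dot product w^T x; assumes w and x have the same length.
    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int i = 0; i < w.length; i++) {
            s += w[i] * x[i];
        }
        return s;
    }

    // Mean binary cross-entropy:
    // L = -(1/n) * sum_i [ y_i * ln(p_i) + (1 - y_i) * ln(1 - p_i) ].
    // Probabilities are clipped to [eps, 1 - eps] (assumed eps) to avoid ln(0).
    static double binaryCrossEntropy(double[] y, double[] p) {
        final double eps = 1e-12;
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double pi = Math.min(Math.max(p[i], eps), 1.0 - eps);
            sum += y[i] * Math.log(pi) + (1.0 - y[i]) * Math.log(1.0 - pi);
        }
        return -sum / y.length;
    }
}
```

For example, predicting 0.5 for every sample gives a loss of ln 2 regardless of the labels, which is the loss of an untrained model whose weights are all zero.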
  • Constructor Summary

    LogisticRegression()
    Creates a logistic regression model.
    LogisticRegression(double learningRate)
    Creates a logistic regression model.
    LogisticRegression(double learningRate, int maxIterations)
    Creates a logistic regression model.
    LogisticRegression(double learningRate, int maxIterations, double threshold)
    Creates a logistic regression model.
  • Method Summary

    Model<double[][], double[]> fit(double[][] features, double[] targets)
    Fits or trains the model with the given features and targets.
    double[] getLossHist()
    Gets the loss history from the optimizer.
    linalg.Matrix getParams()
    Gets the parameters of the trained model.
    String inspect()
    Forms a string of the important aspects of the model. Same as toString().
    double[] predict(double[][] features)
    Uses the fitted/trained model to make predictions on the given features.
    linalg.Matrix predict(linalg.Matrix X, linalg.Matrix w)
    Makes a prediction using a model by specifying the parameters of the model.
    void saveModel(String filePath)
    Saves a trained model to the specified file path.
    String toString()
    Forms a string of the important aspects of the model.

    Methods inherited from class com.jml.core.Model


    Methods inherited from class java.lang.Object

    equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
  • Constructor Details

    • LogisticRegression

      public LogisticRegression()
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with default settings: a learning rate of 0.002, 1000 max iterations, and a threshold of 0.5e-5.
    • LogisticRegression

      public LogisticRegression(double learningRate, int maxIterations, double threshold)
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with specified learning rate, max iterations, and threshold.
      Parameters:
      learningRate - Learning rate to use during optimization.
      maxIterations - Maximum number of iterations to run the optimizer for.
      threshold - Threshold for stopping the optimizer. If the loss becomes less than this value, the optimizer will stop early.
    • LogisticRegression

      public LogisticRegression(double learningRate, int maxIterations)
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with the specified learning rate and max iterations. Defaults to a threshold of 0.5e-5.
      Parameters:
      learningRate - Learning rate to use during optimization.
      maxIterations - Maximum number of iterations to run the optimizer for.
    • LogisticRegression

      public LogisticRegression(double learningRate)
      Creates a logistic regression model. The model will be fit using a stochastic gradient descent optimizer with specified learning rate. Defaults to 1000 max iterations and a threshold of 0.5e-5.
      Parameters:
      learningRate - Learning rate to use during optimization.
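The four constructors differ only in which settings fall back to defaults. As a hypothetical sketch (the class name `LogisticRegressionSketch` and its fields are illustrative, not the library's implementation), such overloads can chain to one constructor with `this(...)` so that the documented defaults of 0.002, 1000, and 0.5e-5 are declared in a single place.

```java
// Hypothetical sketch of constructor chaining; only the default values are
// taken from the documentation, everything else is illustrative.
class LogisticRegressionSketch {
    static final double DEFAULT_LEARNING_RATE = 0.002;
    static final int DEFAULT_MAX_ITERATIONS = 1000;
    static final double DEFAULT_THRESHOLD = 0.5e-5;

    final double learningRate;
    final int maxIterations;
    final double threshold;

    // No-arg constructor: every setting uses its default.
    LogisticRegressionSketch() {
        this(DEFAULT_LEARNING_RATE);
    }

    // Custom learning rate; default iterations and threshold.
    LogisticRegressionSketch(double learningRate) {
        this(learningRate, DEFAULT_MAX_ITERATIONS);
    }

    // Custom learning rate and iterations; default threshold.
    LogisticRegressionSketch(double learningRate, int maxIterations) {
        this(learningRate, maxIterations, DEFAULT_THRESHOLD);
    }

    // The only constructor that assigns fields; all others delegate here.
    LogisticRegressionSketch(double learningRate, int maxIterations, double threshold) {
        this.learningRate = learningRate;
        this.maxIterations = maxIterations;
        this.threshold = threshold;
    }
}
```

With this pattern, `new LogisticRegressionSketch(0.01)` keeps the default 1000 iterations and 0.5e-5 threshold, and changing a default requires editing one constant.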
  • Method Details

    • fit

      public Model<double[][], double[]> fit(double[][] features, double[] targets)
      Fits or trains the model with the given features and targets.
      Specified by:
      fit in class Model<double[][], double[]>
      Parameters:
      features - The features of the training set.
      targets - The targets of the training set.
      Returns:
      This model, i.e., the trained model.
      Throws:
      IllegalArgumentException - Thrown if the features and targets are not correctly sized per the specification when the model was compiled.
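The following is a minimal sketch of what fitting involves: repeated gradient steps on the mean binary cross-entropy, a recorded loss per iteration, and an early stop once the loss drops below the threshold. The documented optimizer is stochastic gradient descent; this sketch swaps in plain full-batch gradient descent so the result is deterministic. Prepending a constant 1 to each row as a bias feature is also an assumption, as is the class name `GradientDescentFit`.

```java
import java.util.Arrays;

// Hypothetical sketch: full-batch gradient descent (not the documented SGD)
// on the mean binary cross-entropy, with loss history and early stopping.
class GradientDescentFit {
    double[] w;         // learned parameters; w[0] is an assumed bias term
    double[] lossHist;  // loss computed at each iteration

    GradientDescentFit fit(double[][] features, double[] targets,
                           double learningRate, int maxIterations, double threshold) {
        if (features.length == 0 || features.length != targets.length) {
            throw new IllegalArgumentException("features and targets must have the same, non-zero number of rows");
        }
        int n = features.length;
        int d = features[0].length + 1;  // +1 for the bias feature
        for (double[] row : features) {
            if (row.length != d - 1) {
                throw new IllegalArgumentException("all feature rows must have the same length");
            }
        }
        w = new double[d];
        double[] hist = new double[maxIterations];
        int iter = 0;
        while (iter < maxIterations) {
            double[] grad = new double[d];
            double loss = 0.0;
            for (int i = 0; i < n; i++) {
                // z = w^T x, where x = [1, features[i][0], features[i][1], ...]
                double z = w[0];
                for (int j = 1; j < d; j++) z += w[j] * features[i][j - 1];
                double p = 1.0 / (1.0 + Math.exp(-z));
                double pc = Math.min(Math.max(p, 1e-12), 1.0 - 1e-12);
                loss -= targets[i] * Math.log(pc) + (1.0 - targets[i]) * Math.log(1.0 - pc);
                // Gradient of the cross-entropy of a sigmoid: (p - y) * x
                double err = p - targets[i];
                grad[0] += err;
                for (int j = 1; j < d; j++) grad[j] += err * features[i][j - 1];
            }
            loss /= n;
            hist[iter++] = loss;
            if (loss < threshold) break;  // early stop: loss below threshold
            for (int j = 0; j < d; j++) w[j] -= learningRate * grad[j] / n;
        }
        lossHist = Arrays.copyOf(hist, iter);  // trim unused iterations
        return this;  // return the trained model, mirroring fit's contract
    }
}
```

Returning `this` allows the fit call to be chained directly from the constructor, as in `new GradientDescentFit().fit(x, y, 0.5, 500, 1e-9)`.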
    • predict

      public double[] predict(double[][] features)
      Uses the fitted/trained model to make predictions on the given features.
      Specified by:
      predict in class Model<double[][], double[]>
      Parameters:
      features - The features to make predictions on.
      Returns:
      The model's predicted labels.
      Throws:
      IllegalArgumentException - Thrown if the features are not correctly sized per the specification when the model was compiled.
      IllegalStateException - Thrown if the model has not been compiled and fit.
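The following sketch shows how a fitted parameter vector could turn feature rows into 0/1 labels, including the two documented failure modes. It assumes that `w[0]` is a bias term and that a probability cutoff of 0.5 separates the classes. Neither assumption is stated in this documentation, and the class name `LabelPredictor` is hypothetical.

```java
// Hypothetical sketch: label prediction from stored parameters, assuming a
// bias in w[0] and a 0.5 probability cutoff.
class LabelPredictor {
    private final double[] w;  // null means the model has not been fit

    LabelPredictor(double[] w) {
        this.w = w;
    }

    double[] predict(double[][] features) {
        if (w == null) {
            throw new IllegalStateException("model has not been fit");
        }
        double[] labels = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            // Each row must supply one value per non-bias parameter.
            if (features[i].length != w.length - 1) {
                throw new IllegalArgumentException("row " + i + " has " + features[i].length
                        + " features, expected " + (w.length - 1));
            }
            double z = w[0];
            for (int j = 0; j < features[i].length; j++) z += w[j + 1] * features[i][j];
            double p = 1.0 / (1.0 + Math.exp(-z));
            labels[i] = p >= 0.5 ? 1.0 : 0.0;  // assumed cutoff
        }
        return labels;
    }
}
```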
    • predict

      public linalg.Matrix predict(linalg.Matrix X, linalg.Matrix w)
      Makes a prediction using a model by specifying the parameters of the model. Unlike the other predict method, no model needs to be trained to use this method since the parameters provided define a model.
      Specified by:
      predict in class Model<double[][], double[]>
      Parameters:
      X - Features to make prediction on.
      w - Parameters of the model.
      Returns:
      Prediction on the features using the given model parameters.
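Because the parameters alone define the model, this prediction can be a pure function of `X` and `w`. The following sketch computes $\sigma(Xw)$ for an n-by-d feature matrix and a d-by-1 parameter column, using `double[][]` in place of `linalg.Matrix`. It assumes the method returns probabilities rather than thresholded labels. No bias column is added, so `X` must already contain one if `w` expects it. The class name `StatelessPredict` is hypothetical.

```java
// Hypothetical sketch of a stateless prediction sigma(X w) using plain
// 2D arrays; returning probabilities (not 0/1 labels) is an assumption.
class StatelessPredict {
    static double[][] predict(double[][] X, double[][] w) {
        int d = w.length;  // w is a d-by-1 column
        double[][] out = new double[X.length][1];
        for (int i = 0; i < X.length; i++) {
            if (X[i].length != d) {
                throw new IllegalArgumentException("X has " + X[i].length
                        + " columns but w has " + d + " rows");
            }
            // z_i = (X w)_i
            double z = 0.0;
            for (int j = 0; j < d; j++) z += X[i][j] * w[j][0];
            out[i][0] = 1.0 / (1.0 + Math.exp(-z));
        }
        return out;  // n-by-1 column of probabilities
    }
}
```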
    • getParams

      public linalg.Matrix getParams()
      Gets the parameters of the trained model.
      Specified by:
      getParams in class Model<double[][], double[]>
      Returns:
      A matrix containing the parameters of the trained model.
    • getLossHist

      public double[] getLossHist()
      Gets the loss history from the optimizer.
      Returns:
      The loss for each iteration of the optimization algorithm, as an array. The index of the array corresponds to the iteration the loss was computed for.
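Since index `i` of the loss history holds the loss from iteration `i`, simple scans over the array can answer convergence questions. The helpers below are hypothetical and not part of the JML API. One finds the first iteration whose loss fell below a threshold, and the other checks that the loss never increased.

```java
// Hypothetical helpers for inspecting a loss-history array in which
// index i holds the loss computed at iteration i.
class LossHistoryTools {

    // Index of the first iteration whose loss is below threshold, or -1.
    static int firstBelow(double[] lossHist, double threshold) {
        for (int i = 0; i < lossHist.length; i++) {
            if (lossHist[i] < threshold) return i;
        }
        return -1;
    }

    // True if the loss never increased from one iteration to the next.
    static boolean isNonIncreasing(double[] lossHist) {
        for (int i = 1; i < lossHist.length; i++) {
            if (lossHist[i] > lossHist[i - 1]) return false;
        }
        return true;
    }
}
```

With gradient descent, a history that is not non-increasing often indicates that the learning rate is too large.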
    • saveModel

      public void saveModel(String filePath)
      Saves a trained model to the specified file path.
      Specified by:
      saveModel in class Model<double[][], double[]>
      Parameters:
      filePath - File path, including extension, to save the fitted/trained model to.
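This documentation does not specify the file format that `saveModel` writes. As a purely hypothetical illustration of persisting a trained model, the sketch below writes a parameter vector as plain text with one value per line and reads it back. It uses the standard `java.nio.file.Files` API. The class `ParamFile` and its format are assumptions, not the library's behavior.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: save/load parameters as plain text, one per line.
// This is NOT the documented saveModel format, which is unspecified here.
class ParamFile {

    static void save(double[] params, Path path) throws IOException {
        List<String> lines = new ArrayList<>();
        for (double p : params) {
            lines.add(Double.toString(p));  // round-trips exactly via parseDouble
        }
        Files.write(path, lines);  // UTF-8, creates or truncates the file
    }

    static double[] load(Path path) throws IOException {
        List<String> lines = Files.readAllLines(path);
        double[] params = new double[lines.size()];
        for (int i = 0; i < params.length; i++) {
            params[i] = Double.parseDouble(lines.get(i).trim());
        }
        return params;
    }
}
```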
    • inspect

      public String inspect()
      Forms a string of the important aspects of the model. Same as toString().
      Specified by:
      inspect in class Model<double[][], double[]>
      Returns:
      Details of the model as a string.
    • toString

      public String toString()
      Forms a string of the important aspects of the model.
      Specified by:
      toString in class Model<double[][], double[]>
      Returns:
      String representation of the model.
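The documentation states that `inspect()` is the same as `toString()`, which suggests delegating one method to the other. The following sketch does this and assumes that "important aspects" means the hyperparameters plus the learned parameters. The class `SummaryModel` and the output layout are hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch: inspect() delegates to toString(), which reports
// hyperparameters and (if trained) the learned parameters.
class SummaryModel {
    final double learningRate;
    final int maxIterations;
    final double threshold;
    double[] params;  // null until the model is trained

    SummaryModel(double learningRate, int maxIterations, double threshold) {
        this.learningRate = learningRate;
        this.maxIterations = maxIterations;
        this.threshold = threshold;
    }

    // Same output as toString(), as the documentation describes.
    String inspect() {
        return toString();
    }

    @Override
    public String toString() {
        return "LogisticRegression(learningRate=" + learningRate
                + ", maxIterations=" + maxIterations
                + ", threshold=" + threshold
                + ", params=" + (params == null ? "untrained" : Arrays.toString(params)) + ")";
    }
}
```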