Package com.jml.linear_models
Class LinearRegressionSGD
java.lang.Object
com.jml.core.Model<double[],double[]>
com.jml.linear_models.PolynomialRegression
com.jml.linear_models.LinearRegression
com.jml.linear_models.LinearRegressionSGD
Model for least squares linear regression of one variable by stochastic gradient descent.
LinearRegressionSGD fits the model y = b0 + b1x to a dataset by minimizing the sum of the squared residuals between the values in the target dataset and the values predicted by the model. The minimization is performed with stochastic gradient descent.
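The hypothetical helper below (not part of com.jml; the class and method names are illustrative) is a minimal sketch of a single stochastic gradient descent step for this model. It updates b0 and b1 from one training sample using the gradient of that sample's squared error. The library's actual update rule, including any constant factors, may differ.

```java
/**
 * Hypothetical helper (not part of com.jml): one stochastic gradient descent
 * update for the model y = b0 + b1 * x using a single training sample.
 */
final class SgdStep {

    private SgdStep() {}

    /**
     * Returns the updated coefficients {b0, b1} after one step on sample (x, y).
     * The per-sample loss is (prediction - y)^2, so its gradient is
     * 2 * error with respect to b0 and 2 * error * x with respect to b1.
     */
    static double[] step(double[] params, double x, double y, double learningRate) {
        double b0 = params[0];
        double b1 = params[1];
        double error = (b0 + b1 * x) - y;   // prediction minus target
        double gradB0 = 2.0 * error;        // d/d(b0) of error^2
        double gradB1 = 2.0 * error * x;    // d/d(b1) of error^2
        return new double[] {
            b0 - learningRate * gradB0,
            b1 - learningRate * gradB1
        };
    }
}
```

When a sample is already predicted exactly, the error is zero and the step leaves both coefficients unchanged.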
-
Constructor Summary
LinearRegressionSGD()
    Creates a LinearRegressionSGD model that uses the default settings for gradient descent.
LinearRegressionSGD(double learningRate)
    Creates a LinearRegressionSGD model.
LinearRegressionSGD(double learningRate, int maxIterations)
    Creates a LinearRegressionSGD model.
LinearRegressionSGD(double learningRate, int maxIterations, double threshold)
    Creates a LinearRegressionSGD model.
LinearRegressionSGD(int maxIterations)
    Creates a LinearRegressionSGD model.
-
Method Summary
fit(double[] features, double[] targets)
    Fits or trains the model with the given features and targets.
double[] getLossHist()
    Gets the loss history from training.
Methods inherited from class com.jml.linear_models.LinearRegression
    inspect, toString
Methods inherited from class com.jml.linear_models.PolynomialRegression
    getParams, predict, predict, saveModel
-
Constructor Details
-
LinearRegressionSGD
public LinearRegressionSGD()
Creates a LinearRegressionSGD model that uses the default settings for gradient descent:
    Learning Rate: 0.002
    Threshold: 0.5e-5
    Maximum Iterations: 1000
    Scheduler: None
-
LinearRegressionSGD
public LinearRegressionSGD(double learningRate, int maxIterations, double threshold)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent uses the provided learning rate and runs for at most maxIterations iterations. It stops early if the loss falls below the threshold.
Parameters:
    learningRate - Learning rate to use during Stochastic Gradient Descent.
    maxIterations - Maximum number of iterations to run during Stochastic Gradient Descent.
    threshold - Threshold for early stopping during Stochastic Gradient Descent. If the loss is less than the specified threshold, gradient descent stops early.
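The loop below sketches how the learning rate, maxIterations, and threshold interact. It is a hypothetical illustration, not the com.jml implementation. The class names, the zero initialization, the seeded shuffling, and the choice to count one full shuffled pass over the data as one iteration are all assumptions.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical sketch (not the com.jml implementation) of an SGD training loop
 * for y = b0 + b1 * x with a learning rate, a maximum number of iterations,
 * and an early-stopping loss threshold.
 * Assumption: one "iteration" is one shuffled pass over the training data.
 */
final class SgdTrainer {

    /** Fitted coefficients plus the loss recorded after each iteration. */
    static final class Result {
        final double b0;
        final double b1;
        final double[] lossHist;

        Result(double b0, double b1, double[] lossHist) {
            this.b0 = b0;
            this.b1 = b1;
            this.lossHist = lossHist;
        }
    }

    static Result fit(double[] x, double[] y, double learningRate,
                      int maxIterations, double threshold, long seed) {
        double b0 = 0.0, b1 = 0.0;
        double[] lossHist = new double[maxIterations];
        int[] order = new int[x.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        Random rng = new Random(seed);
        int iterations = 0;

        while (iterations < maxIterations) {
            shuffle(order, rng);
            // One pass: update the coefficients after every individual sample.
            for (int i : order) {
                double error = (b0 + b1 * x[i]) - y[i];
                b0 -= learningRate * 2.0 * error;
                b1 -= learningRate * 2.0 * error * x[i];
            }
            // Mean squared error over the whole training set after this pass.
            double loss = 0.0;
            for (int i = 0; i < x.length; i++) {
                double r = (b0 + b1 * x[i]) - y[i];
                loss += r * r;
            }
            loss /= x.length;
            lossHist[iterations++] = loss;
            if (loss < threshold) break;   // early stop: loss is below the threshold
        }
        // Keep only the iterations that actually ran.
        return new Result(b0, b1, Arrays.copyOf(lossHist, iterations));
    }

    /** Fisher-Yates shuffle so each pass visits the samples in random order. */
    private static void shuffle(int[] a, Random rng) {
        for (int i = a.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }
}
```

On noise-free data such as x = {0, 1, 2, 3, 4} and y = {1, 3, 5, 7, 9}, a call like SgdTrainer.fit(x, y, 0.01, 5000, 1e-8, 42L) should stop well before 5000 iterations, with b0 close to 1 and b1 close to 2.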
-
LinearRegressionSGD
public LinearRegressionSGD(double learningRate, int maxIterations)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent uses the provided learning rate and stops if it has not converged by the specified maximum number of iterations.
Parameters:
    learningRate - Learning rate to use during Stochastic Gradient Descent.
    maxIterations - Maximum number of iterations to run during Stochastic Gradient Descent.
-
LinearRegressionSGD
public LinearRegressionSGD(double learningRate)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent uses the provided learning rate and stops if it has not converged by the default maximum number of iterations.
Parameters:
    learningRate - Learning rate to use during Stochastic Gradient Descent.
-
LinearRegressionSGD
public LinearRegressionSGD(int maxIterations)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent uses the default learning rate and stops if it has not converged by the specified maximum number of iterations.
Parameters:
    maxIterations - Maximum number of iterations to run during Stochastic Gradient Descent.
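The four constructors differ only in which settings the caller supplies. The hypothetical class below (not com.jml code) sketches one common way to implement such overloads: each one chains to the three-argument form. It assumes that any setting a constructor does not take falls back to the defaults documented for LinearRegressionSGD().

```java
/**
 * Hypothetical configuration class (not com.jml code) mirroring the four
 * LinearRegressionSGD constructors. Assumption: every overload falls back to
 * the documented no-argument defaults for the values it does not take.
 */
final class SgdSettings {
    static final double DEFAULT_LEARNING_RATE = 0.002;
    static final int DEFAULT_MAX_ITERATIONS = 1000;
    static final double DEFAULT_THRESHOLD = 0.5e-5;

    final double learningRate;
    final int maxIterations;
    final double threshold;

    SgdSettings() {
        this(DEFAULT_LEARNING_RATE, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    SgdSettings(double learningRate) {
        this(learningRate, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    SgdSettings(int maxIterations) {
        this(DEFAULT_LEARNING_RATE, maxIterations, DEFAULT_THRESHOLD);
    }

    SgdSettings(double learningRate, int maxIterations) {
        this(learningRate, maxIterations, DEFAULT_THRESHOLD);
    }

    // All other constructors delegate here, so defaults live in one place.
    SgdSettings(double learningRate, int maxIterations, double threshold) {
        this.learningRate = learningRate;
        this.maxIterations = maxIterations;
        this.threshold = threshold;
    }
}
```

Both a double and an int single-argument overload exist, so Java chooses between them from the argument's static type. For example, new LinearRegressionSGD(500) selects the maxIterations overload, while new LinearRegressionSGD(500.0) treats 500.0 as a learning rate.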
-
-
Method Details
-
fit
Fits or trains the model with the given features and targets.
Overrides:
    fit in class LinearRegression
Parameters:
    features - The features of the training set.
    targets - The targets of the training set.
Returns:
    Details of the fitting / training process.
Throws:
    IllegalArgumentException - Can be thrown for the following reasons:
        - If key-value pairs in args are unspecified or invalid.
        - If the features and targets are not correctly sized per the specification when the model was compiled.
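The hypothetical sketch below shows the kind of size check described above. It assumes that, for one-variable regression, "correctly sized" means non-null, non-empty arrays of equal length. The actual checks in com.jml may differ.

```java
/**
 * Hypothetical validation helper (not com.jml code). Assumption: for
 * one-variable regression, "correctly sized" means both arrays are non-null,
 * non-empty, and the same length.
 */
final class FitInputs {

    private FitInputs() {}

    static void validate(double[] features, double[] targets) {
        if (features == null || targets == null) {
            throw new IllegalArgumentException("features and targets must not be null");
        }
        if (features.length == 0) {
            throw new IllegalArgumentException("training set must not be empty");
        }
        if (features.length != targets.length) {
            throw new IllegalArgumentException(
                "features and targets must have the same length, got "
                + features.length + " and " + targets.length);
        }
    }
}
```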
-
getLossHist
public double[] getLossHist()
Gets the loss history from training.
Returns:
    The loss at every iteration, stored in an array.
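Because getLossHist() returns a plain double[], ordinary array code can inspect the training curve. The helpers below are hypothetical and assume that entry i holds the loss after iteration i + 1. With stochastic gradient descent, the loss is not guaranteed to decrease at every iteration. The monotonicity check is therefore a diagnostic, not a requirement.

```java
/**
 * Hypothetical helpers (not com.jml code) for inspecting a loss history
 * returned as double[], where entry i is the loss after iteration i + 1.
 */
final class LossHistory {

    private LossHistory() {}

    /** Iterations needed for the loss to drop below target, or -1 if it never did. */
    static int iterationsToReach(double[] lossHist, double target) {
        for (int i = 0; i < lossHist.length; i++) {
            if (lossHist[i] < target) return i + 1;
        }
        return -1;
    }

    /** True if the loss never increased from one iteration to the next. */
    static boolean isMonotone(double[] lossHist) {
        for (int i = 1; i < lossHist.length; i++) {
            if (lossHist[i] > lossHist[i - 1]) return false;
        }
        return true;
    }
}
```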
-