Package com.jml.linear_models
Class LinearRegressionSGD
java.lang.Object
com.jml.core.Model<double[],double[]>
com.jml.linear_models.PolynomialRegression
com.jml.linear_models.LinearRegression
com.jml.linear_models.LinearRegressionSGD
Model for least squares linear regression of one variable by stochastic gradient descent.
LinearRegressionSGD fits a model y = b0 + b1x to the dataset by minimizing the sum of squared residuals between the values in the target dataset and the values predicted by the model. The coefficients b0 and b1 are found using stochastic gradient descent.
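The fitting procedure can be illustrated with a minimal sketch. For a single sample with residual e = (b0 + b1x) - y, the gradient of ½e² is (e, e·x), so each stochastic step moves b0 and b1 against that gradient. The sketch assumes zero initial coefficients and samples drawn uniformly at random with replacement. The class `SimpleSgdSketch` and its `fit` method are hypothetical and are not JML's implementation.

```java
/**
 * Hypothetical sketch (not the JML implementation) of simple linear
 * regression y = b0 + b1*x fit by stochastic gradient descent on the
 * squared error of one randomly chosen sample per step.
 */
final class SimpleSgdSketch {

    /** Returns {b0, b1}; each epoch performs x.length random single-sample steps. */
    static double[] fit(double[] x, double[] y, double learningRate, int epochs, long seed) {
        if (x.length != y.length || x.length == 0) {
            throw new IllegalArgumentException("x and y must be non-empty and the same length");
        }
        java.util.Random rng = new java.util.Random(seed);
        double b0 = 0.0, b1 = 0.0; // assumed starting point
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int step = 0; step < x.length; step++) {
                int i = rng.nextInt(x.length);          // pick one sample at random
                double error = (b0 + b1 * x[i]) - y[i]; // residual for that sample
                // Gradient of 0.5 * error^2 is (error, error * x[i])
                b0 -= learningRate * error;
                b1 -= learningRate * error * x[i];
            }
        }
        return new double[] {b0, b1};
    }
}
```

On noise-free data such as y = 1 + 2x, a small learning rate drives the estimates to b0 ≈ 1 and b1 ≈ 2.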
Constructor Summary
LinearRegressionSGD() - Creates a LinearRegressionSGD model using the default settings for gradient descent.
LinearRegressionSGD(double learningRate) - Creates a LinearRegressionSGD model.
LinearRegressionSGD(double learningRate, int maxIterations) - Creates a LinearRegressionSGD model.
LinearRegressionSGD(double learningRate, int maxIterations, double threshold) - Creates a LinearRegressionSGD model.
LinearRegressionSGD(int maxIterations) - Creates a LinearRegressionSGD model.

Method Summary
fit(double[] features, double[] targets) - Fits or trains the model with the given features and targets.
double[] getLossHist() - Gets the loss history from training.
Methods inherited from class com.jml.linear_models.LinearRegression: inspect, toString
Methods inherited from class com.jml.linear_models.PolynomialRegression: getParams, predict, predict, saveModel
Constructor Details
LinearRegressionSGD
public LinearRegressionSGD()
Creates a LinearRegressionSGD model. This will use the following default settings for gradient descent:
- Learning Rate: 0.002
- Threshold: 0.5e-5
- Maximum Iterations: 1000
- Scheduler: None
LinearRegressionSGD
public LinearRegressionSGD(double learningRate, int maxIterations, double threshold)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent will use the provided learning rate and will stop early once the loss falls below the threshold, or after maxIterations iterations if it has not converged by then.
- Parameters:
learningRate - Learning rate to use during Stochastic Gradient Descent.
maxIterations - Maximum number of iterations to run during Stochastic Gradient Descent.
threshold - Threshold for early stopping during Stochastic Gradient Descent. If the loss is less than the specified threshold, gradient descent will stop early.
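The stopping rule described for this constructor can be sketched as a training loop that checks the loss after every iteration. The sketch assumes the loss is the mean squared error over the training set and that one iteration is one pass over the samples in a freshly shuffled order; JML may define both differently. `EarlyStoppingSketch`, `mse`, and `train` are hypothetical names.

```java
/**
 * Hypothetical sketch (not JML's code) of SGD for y = b0 + b1*x with the
 * documented stopping rule: stop early once the loss drops below threshold,
 * otherwise stop after maxIterations. Loss = mean squared error (assumed).
 */
final class EarlyStoppingSketch {

    /** Mean squared error of the line y = b0 + b1*x on the data set. */
    static double mse(double[] x, double[] y, double b0, double b1) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double e = (b0 + b1 * x[i]) - y[i];
            sum += e * e;
        }
        return sum / x.length;
    }

    /** Trains from b0 = b1 = 0 and returns the number of iterations performed. */
    static int train(double[] x, double[] y, double learningRate,
                     int maxIterations, double threshold) {
        java.util.Random rng = new java.util.Random(0L);
        int[] order = new int[x.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        double b0 = 0.0, b1 = 0.0;
        for (int iter = 1; iter <= maxIterations; iter++) {
            // One iteration = one pass over the samples in a shuffled order (Fisher-Yates)
            for (int i = order.length - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int i : order) {
                double e = (b0 + b1 * x[i]) - y[i]; // residual for sample i
                b0 -= learningRate * e;
                b1 -= learningRate * e * x[i];
            }
            if (mse(x, y, b0, b1) < threshold) {
                return iter; // loss below threshold: stop early
            }
        }
        return maxIterations; // did not converge within the threshold
    }
}
```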
LinearRegressionSGD
public LinearRegressionSGD(double learningRate, int maxIterations)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent will use the provided learning rate and will stop if it does not converge within the specified maximum number of iterations.
- Parameters:
learningRate - Learning rate to use during Stochastic Gradient Descent.
maxIterations - Maximum number of iterations to run during Stochastic Gradient Descent.
LinearRegressionSGD
public LinearRegressionSGD(double learningRate)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent will use the provided learning rate and will stop if it does not converge within the default maximum number of iterations.
- Parameters:
learningRate - Learning rate to use during Stochastic Gradient Descent.
LinearRegressionSGD
public LinearRegressionSGD(int maxIterations)
Creates a LinearRegressionSGD model. When the fit method is called, Stochastic Gradient Descent will use the default learning rate and will stop if it does not converge within the specified maximum number of iterations.
- Parameters:
maxIterations - Maximum number of iterations to run during Stochastic Gradient Descent.
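The overloaded constructors can be read as filling in the documented defaults (learning rate 0.002, threshold 0.5e-5, maximum iterations 1000) for any setting that is not passed. A minimal sketch of that pattern using constructor chaining follows; `SgdSettingsSketch` is a hypothetical class, not JML's source, and it omits the scheduler setting.

```java
/**
 * Hypothetical settings holder (not JML's source) showing how overloaded
 * constructors can delegate to one full constructor, substituting the
 * documented defaults for any argument that is not supplied.
 */
final class SgdSettingsSketch {
    static final double DEFAULT_LEARNING_RATE = 0.002;
    static final int DEFAULT_MAX_ITERATIONS = 1000;
    static final double DEFAULT_THRESHOLD = 0.5e-5;

    final double learningRate;
    final int maxIterations;
    final double threshold;

    SgdSettingsSketch() {
        this(DEFAULT_LEARNING_RATE, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    SgdSettingsSketch(double learningRate) {
        this(learningRate, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    SgdSettingsSketch(int maxIterations) {
        this(DEFAULT_LEARNING_RATE, maxIterations, DEFAULT_THRESHOLD);
    }

    SgdSettingsSketch(double learningRate, int maxIterations) {
        this(learningRate, maxIterations, DEFAULT_THRESHOLD);
    }

    // The single constructor that actually assigns the fields
    SgdSettingsSketch(double learningRate, int maxIterations, double threshold) {
        this.learningRate = learningRate;
        this.maxIterations = maxIterations;
        this.threshold = threshold;
    }
}
```

One consequence holds for any Java class with the documented signatures: `new LinearRegressionSGD(500)` selects the `int maxIterations` constructor, while `new LinearRegressionSGD(500.0)` selects the `double learningRate` constructor, because Java picks the most specific applicable overload.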
Method Details
fit
Fits or trains the model with the given features and targets.
- Overrides:
fit in class LinearRegression
- Parameters:
features - The features of the training set.
targets - The targets of the training set.
- Returns:
- Details of the fitting / training process.
- Throws:
IllegalArgumentException - Can be thrown for the following reasons:
- If key-value pairs in args are unspecified or invalid.
- If the features and targets are not correctly sized per the specification when the model was compiled.
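The size check listed under Throws can be sketched as a guard that runs before training. For a model of one variable this is assumed to mean that features and targets are non-null, non-empty, and of equal length; `FitInputCheckSketch` is hypothetical and JML's exact rules may differ.

```java
/**
 * Hypothetical guard (not JML's code): rejects inputs that cannot be
 * paired one-to-one for a single-variable regression.
 */
final class FitInputCheckSketch {
    static void checkSizes(double[] features, double[] targets) {
        if (features == null || targets == null) {
            throw new IllegalArgumentException("features and targets must not be null");
        }
        if (features.length == 0) {
            throw new IllegalArgumentException("features must not be empty");
        }
        if (features.length != targets.length) {
            throw new IllegalArgumentException(
                "features and targets must have the same length, got "
                + features.length + " and " + targets.length);
        }
    }
}
```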
getLossHist
public double[] getLossHist()
Gets the loss history from training.
- Returns:
- The loss of every iteration, stored in an array.
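A loss history of this kind can be kept by recording one value per iteration and returning a `double[]`. Because early stopping can end training before maxIterations, the sketch trims the array to the iterations actually run and returns a copy. Whether JML trims or copies the array is an assumption, and `LossHistorySketch` is a hypothetical class.

```java
/**
 * Hypothetical recorder (not JML's code): stores one loss per iteration in a
 * preallocated array and returns only the filled part as a defensive copy.
 */
final class LossHistorySketch {
    private final double[] values;
    private int count = 0;

    LossHistorySketch(int maxIterations) {
        values = new double[maxIterations];
    }

    /** Stores the loss for the current iteration. */
    void record(double loss) {
        if (count == values.length) {
            throw new IllegalStateException("more iterations than maxIterations");
        }
        values[count++] = loss;
    }

    /** Returns the recorded losses, trimmed to the number of iterations run. */
    double[] toArray() {
        return java.util.Arrays.copyOf(values, count);
    }
}
```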