<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     *  - 'log' : log likelihood <br>
     *  - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of the regularization term. If λ is set to 0, then
     * the regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initializes a Logistic Regression classifier with the maximum number of
     * iterations and the learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm
     * will be standardized by use of standard deviation and mean calculation <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (regularization term) can be 'L2' or an empty string to cancel the penalty term
     *
     * @throws InvalidArgumentException
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new InvalidArgumentException(
                'Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms'
            );
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new InvalidArgumentException(
                "Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors"
            );
        }

        if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
            throw new InvalidArgumentException("Logistic regression supports only 'L2' regularization");
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }

    /**
     * Sets the learning rate if the gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of the regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of the selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;
            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;
            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;
            default:
                // Not reached
                throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
        }
    }

    /**
     * Executes the Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of the log-likelihood cost function to be minimized:
                 *        J(x) = ∑( -y . log(h(x)) - (1 - y) . log(1 - h(x)) )
                 *
                 * If a regularization term is given, it is added to the cost:
                 *        for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *        ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we fix these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    // Map the -1/+1 class labels to 0/1
                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };
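            /*
             * Worked example for the 'log' cost above (illustrative only, not
             * part of the original source): with h(x) = 0.9,
             *   y = 1: error = -log(0.9)     ≈ 0.105, gradient = 0.9 - 1 = -0.1
             *   y = 0: error = -log(1 - 0.9) ≈ 2.303, gradient = 0.9 - 0 =  0.9
             * A confident wrong prediction is penalized far more heavily than a
             * confident correct one, which is the point of the log loss.
             */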
            case 'sse':
                /*
                 * Sum of squared errors (least squares) cost function:
                 *        J(x) = ∑ (y - h(x))^2
                 *
                 * If a regularization term is given, it is added to the cost:
                 *        for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *        ∇J(x) = -(y - h(x)) . h(x) . (1 - h(x))
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // Map the -1/+1 class labels to 0/1
                    $y = $y < 0 ? 0 : 1;

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };
            default:
                // Not reached
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        // Logistic (sigmoid) function
        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is the logistic (sigmoid) output computed from the
     * sample's distance to the decision plane.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}
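/*
 * Example usage (a minimal sketch, not part of the original file). It assumes
 * the train()/predict() API that php-ml classifiers inherit from their base
 * traits; the sample data below is made up for illustration.
 *
 *     use Phpml\Classification\Linear\LogisticRegression;
 *
 *     $samples = [[0.1, 0.2], [0.2, 0.1], [0.8, 0.9], [0.9, 0.8]];
 *     $targets = [0, 0, 1, 1];
 *
 *     $classifier = new LogisticRegression(
 *         500,                                         // max iterations
 *         true,                                        // standardize inputs
 *         LogisticRegression::CONJUGATE_GRAD_TRAINING, // solver
 *         'log',                                       // cost function
 *         'L2'                                         // penalty
 *     );
 *     $classifier->train($samples, $targets);
 *     $classifier->predict([0.85, 0.85]); // expected to return the second class
 */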