<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     * - 'log' : log likelihood <br>
     * - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of regularization term. If λ is set to 0, then
     * regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initializes a Logistic Regression classifier with the maximum number of iterations
     * and the learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
     * by use of standard deviation and mean calculation <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (Regularization term) can be 'L2' or an empty string to cancel the penalty term
     *
     * @throws InvalidArgumentException
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new InvalidArgumentException(
                'Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms'
            );
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new InvalidArgumentException(
                "Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors"
            );
        }

        if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
            throw new InvalidArgumentException('Logistic regression supports only \'L2\' regularization');
        }

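        // A small default learning rate for the gradient descent solvers;
        // it can be tuned afterwards with setLearningRate().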
        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }

    /**
     * Sets the learning rate if a gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to the given samples and targets
     * by use of the selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

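        // Dispatch to the solver selected at construction time.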
        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;

            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;

            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;

            default:
                // Not reached
                throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
        }
    }

    /**
     * Executes the Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
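        // The optimizer is created lazily and kept on the instance, so subsequent
        // training runs reuse it.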
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
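        // Pass λ as the penalty term only when L2 regularization is enabled; 0 disables it.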
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of Log-likelihood cost function to be minimized:
                 *      J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we fix these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

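                    // Targets below 0 are treated as class 0, everything else as
                    // class 1, as expected by the log-loss formula.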
                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };
            case 'sse':
                /*
                 * Sum of squared errors or least squared errors cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function (as implemented below):
                 *      ∇J(x) = (h(x) - y) . h(x) . (1 - h(x))
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $y = $y < 0 ? 0 : 1;

                    $error = (($y - $hX) ** 2);
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };
            default:
                // Not reached
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

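        // Squash the linear combination computed by the parent (Adaline) through
        // the logistic (sigmoid) function.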
        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

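        // Probabilities above 0.5 map to class +1, the rest to class -1.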
        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is the logistic (sigmoid) output of the model for the
     * sample: it is returned directly for the positive label and complemented
     * for the other one.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

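        // output() is treated as the probability of the label found at a non-zero
        // index of $this->labels; the first label gets the complement.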
        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}