| 1 |
efrain |
1 |
<?php
|
|
|
2 |
|
|
|
3 |
declare(strict_types=1);
|
|
|
4 |
|
|
|
5 |
namespace Phpml\Classification\Ensemble;
|
|
|
6 |
|
|
|
7 |
use Phpml\Classification\Classifier;
|
|
|
8 |
use Phpml\Classification\Linear\DecisionStump;
|
|
|
9 |
use Phpml\Classification\WeightedClassifier;
|
|
|
10 |
use Phpml\Exception\InvalidArgumentException;
|
|
|
11 |
use Phpml\Helper\Predictable;
|
|
|
12 |
use Phpml\Helper\Trainable;
|
|
|
13 |
use Phpml\Math\Statistic\Mean;
|
|
|
14 |
use Phpml\Math\Statistic\StandardDeviation;
|
|
|
15 |
use ReflectionClass;
|
|
|
16 |
|
|
|
17 |
/**
 * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to improve the
 * classification performance of 'weak' classifiers such as DecisionStump
 * (the default base classifier of AdaBoost).
 *
 * Only two-class problems are supported: the two labels found in the training
 * targets are encoded internally as 1 and -1, and the final prediction is the
 * sign of the alpha-weighted sum of the weak classifiers' votes.
 */
class AdaBoost implements Classifier
{
    use Predictable;
    use Trainable;

    /**
     * Actual labels given in the targets array, keyed by their internal
     * encoding (1 and -1)
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples currently held for training
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features in the first sample passed to train()
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Class name of the 'weak' classifier instantiated on every iteration
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments handed to the base classifier
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * @param int $maxIterations Number of boosting rounds, i.e. the number of
     *                           weak classifiers that train() will build
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    Class name of the weak classifier
     * @param array  $classifierOptions Constructor arguments for that class
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble: builds up to $maxIterations weak classifiers, each
     * one fitted against the sample weights left behind by its predecessor.
     *
     * @throws InvalidArgumentException when $targets does not contain exactly two distinct labels
     */
    public function train(array $samples, array $targets): void
    {
        // array_count_values() produces one key per distinct label, in first-seen order.
        // Validate before touching $this->labels so a rejected call leaves the model intact.
        $distinctLabels = array_keys(array_count_values($targets));
        if (count($distinctLabels) !== 2) {
            throw new InvalidArgumentException('AdaBoost is a binary classifier and can classify between two classes only');
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $distinctLabels[0],
            -1 => $distinctLabels[1],
        ];

        // Compare as strings: array keys turn "5" into int 5, so a strict comparison
        // would miss matches, while a loose one merges distinct labels such as
        // 10 and "1e1" (or 0 and "a" on PHP 7) into a single class.
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            // $this->targets / $this->samples are not declared in this class;
            // presumably they come from the Trainable trait - confirm there.
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters: every sample starts with the same weight
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts the label of a single sample from the alpha-weighted vote of
     * all trained weak classifiers.
     *
     * @return mixed One of the two labels seen during training; a vote sum
     *               that is not strictly positive resolves to the -1 label
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Returns the classifier with the lowest error rate with the
     * consideration of current sample weights
     *
     * A WeightedClassifier receives the weights directly; any other classifier
     * is trained on a dataset resampled according to those weights.
     */
    protected function getBestClassifier(): Classifier
    {
        $ref = new ReflectionClass($this->baseClassifier);
        /** @var Classifier $classifier */
        $classifier = count($this->classifierOptions) === 0 ? $ref->newInstance() : $ref->newInstanceArgs($this->classifierOptions);

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            [$samples, $targets] = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset
     *
     * Each sample gets a number of draw attempts equal to its rounded z-score
     * shifted so that the lightest sample gets one attempt; every attempt is
     * kept with a 50% chance.
     *
     * @return array Two-element list: [samples, targets]
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);
        $min = min($weights);

        // When all weights are equal (always true on the first boosting round) the
        // deviation is zero; dividing by it throws DivisionByZeroError on PHP 8 and
        // yields NAN/INF on PHP 7. Every weight then sits on the mean: z-score 0.
        $zScore = static function (float $weight) use ($mean, $std): int {
            return $std > 0 ? (int) round(($weight - $mean) / $std) : 0;
        };

        $minZ = $zScore($min);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $attempts = $zScore($weight) - $minZ + 1;
            for ($i = 0; $i < $attempts; ++$i) {
                // Coin flip: drop roughly half of the attempts
                if (random_int(0, 1) === 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the classification error rate
     *
     * The rate is the weight of the misclassified training samples divided by
     * the total weight, so it is independent of the scale of the weights.
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            $predicted = $classifier->predict($sample);
            // Loose comparison on purpose: the base classifier's return type is not known here
            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier
     *
     * alpha = 0.5 * ln((1 - e) / e). The error rate is clamped into
     * [1e-10, 1 - 1e-10] first: e = 0 would divide by zero and e = 1 would
     * give ln(0) = -INF, which turns every sample weight into 0 on the next
     * update.
     */
    protected function calculateAlpha(float $errorRate): float
    {
        $epsilon = 1e-10;
        $errorRate = min(max($errorRate, $epsilon), 1 - $epsilon);

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights
     *
     * Correctly classified samples (desired * output > 0) are scaled down and
     * misclassified ones scaled up by exp(alpha). The result is divided by the
     * sum of the weights *before* the update, so the new weights are not
     * guaranteed to sum to 1; evaluateClassifier() divides by the current
     * total and is unaffected by that.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $sumOfWeights = array_sum($this->weights);
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $weight *= exp(-$alpha * $desired * $output) / $sumOfWeights;

            $weightsT1[] = $weight;
        }

        $this->weights = $weightsT1;
    }
}
|