1 |
efrain |
1 |
<?php
|
|
|
2 |
|
|
|
3 |
declare(strict_types=1);
|
|
|
4 |
|
|
|
5 |
namespace Phpml\DimensionReduction;
|
|
|
6 |
|
|
|
7 |
use Phpml\Exception\InvalidArgumentException;
|
|
|
8 |
use Phpml\Exception\InvalidOperationException;
|
|
|
9 |
use Phpml\Math\Matrix;
|
|
|
10 |
|
|
|
11 |
class LDA extends EigenTransformerBase
{
    /**
     * Whether fit() has completed successfully; transform() refuses to run otherwise.
     *
     * @var bool
     */
    public $fit = false;

    /**
     * Unique class labels seen by the last fit(), in order of first occurrence.
     * Because they are produced from array keys, numeric-string labels such as
     * '1' are stored as integers.
     *
     * @var array
     */
    public $labels = [];

    /**
     * Per-class column means, indexed by the position of the class in $labels.
     *
     * @var array
     */
    public $means = [];

    /**
     * Number of training samples per class, indexed by the position of the class in $labels.
     *
     * @var array
     */
    public $counts = [];

    /**
     * Column means over the whole training set.
     *
     * @var float[]
     */
    public $overallMean = [];

    /**
     * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
     * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
     * technique that requires the class labels in order to fit the data to a
     * lower dimensional space. <br><br>
     * The algorithm can be initialized by specifying either the totalVariance
     * (a value between 0.1 and 0.99) or numFeatures (number of features in the
     * dataset) to be preserved. Exactly one of the two must be given.
     *
     * @param float|null $totalVariance Total explained variance to be preserved
     * @param int|null   $numFeatures   Number of features to be preserved
     *
     * @throws InvalidArgumentException When a value is out of range, or when both or neither are given
     */
    public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
    {
        if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
            throw new InvalidArgumentException('Total variance can be a value between 0.1 and 0.99');
        }

        if ($numFeatures !== null && $numFeatures <= 0) {
            throw new InvalidArgumentException('Number of features to be preserved should be greater than 0');
        }

        // Exactly one of the two options must be provided.
        if (($totalVariance !== null) === ($numFeatures !== null)) {
            throw new InvalidArgumentException('Either totalVariance or numFeatures should be specified in order to run the algorithm');
        }

        if ($numFeatures !== null) {
            $this->numFeatures = $numFeatures;
        }

        if ($totalVariance !== null) {
            $this->totalVariance = $totalVariance;
        }
    }

    /**
     * Trains the algorithm to transform the given data to a lower dimensional space.
     *
     * Builds the within-class scatter matrix Sw and the between-class scatter
     * matrix Sb. It then runs the eigen decomposition on inverse(Sw) * Sb.
     *
     * @param array $data    Training samples, one row (list of numbers) per sample
     * @param array $classes Class label (int or string) of each sample, aligned by key with $data
     *
     * @return array The training samples passed through reduce() with the new eigenvectors
     *
     * @throws InvalidArgumentException When $data is empty, the array sizes differ, or a label is unusable
     */
    public function fit(array $data, array $classes): array
    {
        // Do not report a previous fit as valid while state is being rebuilt.
        $this->fit = false;

        if ($data === []) {
            throw new InvalidArgumentException('Data to fit must not be empty');
        }

        if (count($data) !== count($classes)) {
            throw new InvalidArgumentException('Size of given arrays does not match');
        }

        $this->labels = $this->getLabels($classes);
        $this->means = $this->calculateMeans($data, $classes);

        $sW = $this->calculateClassVar($data, $classes);
        $sB = $this->calculateClassCov();

        $S = $sW->inverse()->multiply($sB);
        $this->eigenDecomposition($S->toArray());

        $this->fit = true;

        return $this->reduce($data);
    }

    /**
     * Transforms the given sample to a lower dimensional vector by using
     * the eigenVectors obtained in the last run of <code>fit</code>.
     *
     * Accepts either a single sample (flat list of numbers) or a list of samples.
     *
     * @throws InvalidOperationException When fit() has not completed successfully
     */
    public function transform(array $sample): array
    {
        if (!$this->fit) {
            throw new InvalidOperationException('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
        }

        // Wrap a single sample so reduce() always receives a list of rows.
        // `?? null` avoids an undefined-offset warning on empty or non-list input.
        if (!is_array($sample[0] ?? null)) {
            $sample = [$sample];
        }

        return $this->reduce($sample);
    }

    /**
     * Returns unique labels in the dataset, in order of first occurrence.
     *
     * array_count_values() keys the result by label, so numeric-string labels
     * come back as integers (PHP array-key casting).
     */
    protected function getLabels(array $classes): array
    {
        $counts = array_count_values($classes);

        return array_keys($counts);
    }

    /**
     * Calculates the mean of each column for each class and returns an
     * n by m array, where n is the number of labels and m the number of columns.
     *
     * Side effects: it also stores the overall column means in $this->overallMean
     * and the per-class sample counts in $this->counts.
     *
     * @throws InvalidArgumentException When a class is not one of $this->labels
     */
    protected function calculateMeans(array $data, array $classes): array
    {
        $means = [];
        $counts = [];
        $overallMean = array_fill(0, count($data[0]), 0.0);
        $labelIndex = array_flip($this->labels);

        foreach ($data as $index => $row) {
            $label = $this->findLabelIndex($labelIndex, $classes[$index]);

            foreach ($row as $col => $val) {
                if (!isset($means[$label][$col])) {
                    $means[$label][$col] = 0.0;
                }

                $means[$label][$col] += $val;
                $overallMean[$col] += $val;
            }

            if (!isset($counts[$label])) {
                $counts[$label] = 0;
            }

            ++$counts[$label];
        }

        // Turn the per-class column sums into means.
        foreach ($means as $index => $row) {
            foreach ($row as $col => $sum) {
                $means[$index][$col] = $sum / $counts[$index];
            }
        }

        // Calculate overall mean of the dataset for each column
        $numElements = array_sum($counts);
        $map = function ($el) use ($numElements) {
            return $el / $numElements;
        };
        $this->overallMean = array_map($map, $overallMean);
        $this->counts = $counts;

        return $means;
    }

    /**
     * Returns the within-class scatter matrix Sw, an m by m matrix where m is
     * the number of columns.
     *
     * Sw is the sum, over all samples, of the outer product of (sample - mean of
     * its class) with itself.
     *
     * @throws InvalidArgumentException When a class is not one of $this->labels
     */
    protected function calculateClassVar(array $data, array $classes): Matrix
    {
        // Start from an m by m zero matrix (m = number of columns).
        $s = array_fill(0, count($data[0]), array_fill(0, count($data[0]), 0));
        $sW = new Matrix($s, false);
        $labelIndex = array_flip($this->labels);

        foreach ($data as $index => $row) {
            $label = $this->findLabelIndex($labelIndex, $classes[$index]);
            $means = $this->means[$label];

            $row = $this->calculateVar($row, $means);

            $sW = $sW->add($row);
        }

        return $sW;
    }

    /**
     * Returns the between-class scatter matrix Sb, an m by m matrix where m is
     * the number of columns.
     *
     * Sb is the sum, over classes, of N_c times the outer product of
     * (class mean - overall mean) with itself. N_c is the number of samples in
     * the class.
     */
    protected function calculateClassCov(): Matrix
    {
        // Start from an m by m zero matrix (m = number of columns).
        $s = array_fill(0, count($this->overallMean), array_fill(0, count($this->overallMean), 0));
        $sB = new Matrix($s, false);

        foreach ($this->means as $index => $classMeans) {
            $row = $this->calculateVar($classMeans, $this->overallMean);
            $N = $this->counts[$index];
            $sB = $sB->add($row->multiplyByScalar($N));
        }

        return $sB;
    }

    /**
     * Returns the result of the calculation (x - m)T.(x - m), i.e. the outer
     * product of the row difference with itself.
     */
    protected function calculateVar(array $row, array $means): Matrix
    {
        $x = new Matrix($row, false);
        $m = new Matrix($means, false);
        $diff = $x->subtract($m);

        return $diff->transpose()->multiply($diff);
    }

    /**
     * Resolves a class label to its position within $this->labels.
     *
     * The lookup uses a flipped label array rather than a strict array_search().
     * Numeric-string classes such as '1' are thereby matched against the integer
     * labels produced by getLabels(). A strict search would miss them and
     * collapse every sample into the first class.
     *
     * @param array      $labelIndex Result of array_flip($this->labels)
     * @param int|string $class      Class label of a sample
     *
     * @throws InvalidArgumentException When $class is not an int/string or was not seen by getLabels()
     */
    private function findLabelIndex(array $labelIndex, $class): int
    {
        if ((!is_int($class) && !is_string($class)) || !isset($labelIndex[$class])) {
            throw new InvalidArgumentException('Class labels must be integers or strings present in the training classes');
        }

        return $labelIndex[$class];
    }
}
|