| 1 |
efrain |
1 |
<?php
|
|
|
2 |
|
|
|
3 |
declare(strict_types=1);
|
|
|
4 |
|
|
|
5 |
namespace Phpml\Metric;
|
|
|
6 |
|
|
|
7 |
use Phpml\Exception\InvalidArgumentException;
|
|
|
8 |
|
|
|
9 |
/**
 * Per-label precision, recall, F1 score and support for a set of predictions,
 * together with a single averaged value of each metric.
 *
 * All metrics are computed eagerly in the constructor; the getters only
 * return the stored results. Per-label arrays are keyed by label, in the
 * order produced by sort() over every distinct label seen in either input.
 */
class ClassificationReport
{
    /**
     * Average computed from the true/false positive/negative counts summed over all labels.
     */
    public const MICRO_AVERAGE = 1;

    /**
     * Unweighted mean of the per-label metric values.
     */
    public const MACRO_AVERAGE = 2;

    /**
     * Mean of the per-label metric values weighted by each label's support.
     */
    public const WEIGHTED_AVERAGE = 3;

    /**
     * @var array number of samples per label where the prediction equals the actual label
     */
    private $truePositive = [];

    /**
     * @var array number of samples per label wrongly predicted as that label
     */
    private $falsePositive = [];

    /**
     * @var array number of samples per label that carry the label but were predicted as something else
     */
    private $falseNegative = [];

    /**
     * @var array number of occurrences of each label among the actual labels
     */
    private $support = [];

    /**
     * @var array precision per label
     */
    private $precision = [];

    /**
     * @var array recall per label
     */
    private $recall = [];

    /**
     * @var array F1 score per label
     */
    private $f1score = [];

    /**
     * @var array averaged metrics, keyed by 'precision', 'recall' and 'f1score'
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth label of every sample
     * @param array $predictedLabels predicted label of every sample, using the same keys as $actualLabels
     * @param int   $average         one of MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE
     *
     * @throws InvalidArgumentException when the averaging method is unknown or when the two
     *                                  label arrays cannot be paired element by element
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = [self::MICRO_AVERAGE, self::MACRO_AVERAGE, self::WEIGHTED_AVERAGE];
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        // Without this check a shorter prediction list produces undefined-offset
        // notices and corrupt counts, and a longer one is silently truncated.
        if (count($actualLabels) !== count($predictedLabels)) {
            throw new InvalidArgumentException('Size of given arrays does not match');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array precision per label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array recall per label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array F1 score per label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array number of actual occurrences per label
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array averaged 'precision', 'recall' and 'f1score' for the method chosen at construction
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support per label.
     *
     * @throws InvalidArgumentException when a key of $actualLabels is missing from $predictedLabels
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        // Every counter starts from the same zero-filled, label-keyed template.
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            // Equal sizes do not guarantee equal keys; refuse to pair a sample with nothing.
            if (!array_key_exists($index, $predictedLabels)) {
                throw new InvalidArgumentException('Actual and predicted labels must share the same keys');
            }

            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];

                continue;
            }

            // A miss is a false positive for the predicted label and a false negative for the actual one.
            ++$falsePositive[$predicted];
            ++$falseNegative[$actual];
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Derives precision, recall and F1 score for every label from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
            $this->f1score[$label] = $this->computeF1Score($this->precision[$label], $this->recall[$label]);
        }
    }

    /**
     * Dispatches to the averaging strategy selected by $average.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Computes each metric once from the counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score($precision, $recall);

        $this->average = compact('precision', 'recall', 'f1score');
    }

    /**
     * Averages each metric over the labels with equal weight; 0.0 when there are no labels.
     */
    private function computeMacroAverage(): void
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Averages each metric over the labels, weighting every label by its support.
     * Yields 0.0 when there are no labels or no support to weight by.
     */
    private function computeWeightedAverage(): void
    {
        // Identical for every metric, so compute it once.
        $totalSupport = array_sum($this->support);

        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            // A zero total would be a division by zero below.
            if (count($values) === 0 || $totalSupport === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0.0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * @return float TP / (TP + FP), or 0.0 when the label was never predicted
     */
    private function computePrecision(int $truePositive, int $falsePositive): float
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * @return float TP / (TP + FN), or 0.0 when the label never occurs among the actual labels
     */
    private function computeRecall(int $truePositive, int $falseNegative): float
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * @return float harmonic mean of precision and recall, or 0.0 when both are zero
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider === 0.0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a zero-filled array keyed by every distinct label found in either input, in sorted order.
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
        sort($labels);

        return array_fill_keys($labels, 0);
    }
}
|