| 1 |
efrain |
1 |
<?php
|
|
|
2 |
|
|
|
3 |
declare(strict_types=1);
|
|
|
4 |
|
|
|
5 |
namespace Phpml\Preprocessing;
|
|
|
6 |
|
|
|
7 |
use Phpml\Exception\InvalidArgumentException;
|
|
|
8 |
|
|
|
9 |
/**
 * One-hot encoder for categorical feature columns.
 *
 * fit() records the distinct values found at each column position of the
 * training samples. transform() then replaces every sample with the
 * concatenation of one indicator vector per column: a list of 0s with a single
 * 1 at the slot assigned to the sample's value for that column.
 */
final class OneHotEncoder implements Preprocessor
{
    /**
     * When true, a value that was not seen during fit() is encoded as an
     * all-zero vector instead of raising an exception.
     *
     * @var bool
     */
    private $ignoreUnknown;

    /**
     * Indicator vectors indexed first by column position, then by category value.
     *
     * @var array
     */
    private $categories = [];

    /**
     * @param bool $ignoreUnknown encode unseen values as zeros rather than throwing
     */
    public function __construct(bool $ignoreUnknown = false)
    {
        $this->ignoreUnknown = $ignoreUnknown;
    }

    /**
     * Learns the categories of every column from the given samples.
     *
     * Any categories learned by a previous call are discarded. The number of
     * columns is taken from the first sample.
     *
     * @param array      $samples list of samples, each an array of feature values
     * @param array|null $targets not used by this encoder
     */
    public function fit(array $samples, ?array $targets = null): void
    {
        // Start clean: keeping entries from an earlier fit would mix indicator
        // vectors of different widths within the same column.
        $this->categories = [];

        if ($samples === []) {
            return;
        }

        // transform() addresses features by position (array_values), so do the
        // same here; array_column() would otherwise look columns up by key.
        $rows = array_map('array_values', $samples);
        $width = count(reset($rows));

        for ($column = 0; $column < $width; ++$column) {
            $this->fitColumn($column, array_values(array_unique(array_column($rows, $column))));
        }
    }

    /**
     * Replaces each sample, in place, with its one-hot encoded form.
     *
     * @param array      $samples list of samples, modified by reference
     * @param array|null $targets not used by this encoder
     *
     * @throws InvalidArgumentException when a value is unknown and $ignoreUnknown is false
     */
    public function transform(array &$samples, ?array &$targets = null): void
    {
        foreach ($samples as &$sample) {
            $sample = $this->transformSample(array_values($sample));
        }

        // Break the reference so $sample no longer aliases the last element.
        unset($sample);
    }

    /**
     * Builds and stores the indicator vector of every distinct value of one column.
     *
     * @param int   $column zero-based column position
     * @param array $values distinct values of that column, indexed 0..n-1
     */
    private function fitColumn(int $column, array $values): void
    {
        // Register the column even when it has no values, so transformSample()
        // can tell "known column, unknown value" from "unknown column".
        $this->categories[$column] = [];

        $blank = array_fill(0, count($values), 0);

        foreach ($values as $index => $value) {
            $map = $blank;
            $map[$index] = 1;
            $this->categories[$column][$value] = $map;
        }
    }

    /**
     * Encodes a single positional sample.
     *
     * @param array $sample feature values indexed by column position
     *
     * @return array concatenated indicator vectors, one per column
     *
     * @throws InvalidArgumentException when a value is unknown and $ignoreUnknown is false
     */
    private function transformSample(array $sample): array
    {
        $parts = [];

        foreach ($sample as $column => $feature) {
            // A column that was never fitted has no categories at all.
            $known = $this->categories[$column] ?? [];

            if (isset($known[$feature])) {
                $parts[] = $known[$feature];

                continue;
            }

            if (!$this->ignoreUnknown) {
                throw new InvalidArgumentException(sprintf('Missing category "%s" for column %s in trained encoder', $feature, $column));
            }

            // Unknown value: all zeros, as wide as the column's known categories.
            $parts[] = array_fill(0, count($known), 0);
        }

        // Merge once at the end; the leading [] keeps the call valid for an empty sample.
        return array_merge([], ...$parts);
    }
}
|