1 |
efrain |
1 |
<?php
|
|
|
2 |
|
|
|
3 |
declare(strict_types=1);
|
|
|
4 |
|
|
|
5 |
namespace Phpml\Dataset;
|
|
|
6 |
|
|
|
7 |
use Phpml\Exception\InvalidArgumentException;
|
|
|
8 |
|
|
|
9 |
/**
 * MNIST dataset: http://yann.lecun.com/exdb/mnist/
 * original mnist dataset reader: https://github.com/AndrewCarterUK/mnist-neural-network-plain-php
 *
 * Loads an MNIST image file and its matching label file into an ArrayDataset.
 * Each sample is a flat list of IMAGE_ROWS * IMAGE_COLS pixel values, in the
 * order they appear in the file, scaled to floats in [0, 1]; each target is the
 * corresponding label byte as an int.
 *
 * Files are read as raw bytes, so gzip-compressed downloads must be decompressed
 * first (or opened through a stream wrapper such as compress.zlib://).
 */
final class MnistDataset extends ArrayDataset
{
    /** Expected big-endian magic number at the start of an image file. */
    private const MAGIC_IMAGE = 0x00000803;

    /** Expected big-endian magic number at the start of a label file. */
    private const MAGIC_LABEL = 0x00000801;

    /** Required number of pixel rows per image. */
    private const IMAGE_ROWS = 28;

    /** Required number of pixel columns per image. */
    private const IMAGE_COLS = 28;

    /**
     * Upper bound for a single fread() request, so a corrupt header announcing a
     * huge payload does not turn into one enormous read request.
     */
    private const READ_CHUNK_SIZE = 8192;

    /**
     * @param string $imagePath path to the MNIST image file
     * @param string $labelPath path to the MNIST label file
     *
     * @throws InvalidArgumentException if a file cannot be opened, has an unexpected
     *                                  header, is truncated, or the number of images
     *                                  differs from the number of labels
     */
    public function __construct(string $imagePath, string $labelPath)
    {
        $this->samples = $this->readImages($imagePath);
        $this->targets = $this->readLabels($labelPath);

        if (count($this->samples) !== count($this->targets)) {
            throw new InvalidArgumentException('Must have the same number of images and labels');
        }
    }

    /**
     * Reads every image stored in an MNIST image file.
     *
     * Header layout (four big-endian 32-bit unsigned ints): magic, image count,
     * rows, cols; followed by rows * cols unsigned bytes per image.
     *
     * @return array<int, array<int, float>> one list of pixel values in [0, 1] per image
     *
     * @throws InvalidArgumentException
     */
    private function readImages(string $imagePath): array
    {
        $stream = self::openStream($imagePath);

        $images = [];

        try {
            $fields = unpack('Nmagic/Nsize/Nrows/Ncols', self::readBytes($stream, 16, $imagePath));

            if ($fields['magic'] !== self::MAGIC_IMAGE) {
                throw new InvalidArgumentException('Invalid magic number: '.$imagePath);
            }

            if ($fields['rows'] !== self::IMAGE_ROWS) {
                throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath);
            }

            if ($fields['cols'] !== self::IMAGE_COLS) {
                throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath);
            }

            $imageSize = $fields['rows'] * $fields['cols'];

            for ($i = 0; $i < $fields['size']; ++$i) {
                // readBytes() throws on a truncated file, so a corrupt image count
                // cannot produce short/empty images or loop unboundedly.
                $imageBytes = self::readBytes($stream, $imageSize, $imagePath);

                // Scale each unsigned byte to a float in [0, 1]. Dividing by 255.0
                // (not 255) keeps bytes 0 and 255 from becoming int 0 / int 1.
                $images[] = array_map(static function (int $byte): float {
                    return $byte / 255.0;
                }, array_values(unpack('C*', $imageBytes)));
            }
        } finally {
            fclose($stream);
        }

        return $images;
    }

    /**
     * Reads every label stored in an MNIST label file.
     *
     * Header layout (two big-endian 32-bit unsigned ints): magic, label count;
     * followed by one unsigned byte per label.
     *
     * @return array<int, int> one label byte per entry, in file order
     *
     * @throws InvalidArgumentException
     */
    private function readLabels(string $labelPath): array
    {
        $stream = self::openStream($labelPath);

        try {
            $fields = unpack('Nmagic/Nsize', self::readBytes($stream, 8, $labelPath));

            if ($fields['magic'] !== self::MAGIC_LABEL) {
                throw new InvalidArgumentException('Invalid magic number: '.$labelPath);
            }

            // readBytes() handles a zero count without calling fread($stream, 0),
            // which would throw a ValueError on PHP 8.
            $labels = self::readBytes($stream, $fields['size'], $labelPath);
        } finally {
            fclose($stream);
        }

        if ($labels === '') {
            return [];
        }

        return array_values(unpack('C*', $labels));
    }

    /**
     * Opens $path for binary reading.
     *
     * @return resource
     *
     * @throws InvalidArgumentException if the file cannot be opened
     */
    private static function openStream(string $path)
    {
        $stream = fopen($path, 'rb');

        if ($stream === false) {
            throw new InvalidArgumentException('Could not open file: '.$path);
        }

        return $stream;
    }

    /**
     * Reads exactly $length bytes from $stream.
     *
     * Loops because fread() may return fewer bytes than requested (e.g. on
     * wrapped or non-plain-file streams), and caps each request at
     * READ_CHUNK_SIZE. A premature end of stream is treated as a corrupt file.
     *
     * @param resource $stream
     *
     * @throws InvalidArgumentException if the stream ends before $length bytes were read
     */
    private static function readBytes($stream, int $length, string $path): string
    {
        $buffer = '';

        while (strlen($buffer) < $length) {
            $chunk = fread($stream, min($length - strlen($buffer), self::READ_CHUNK_SIZE));

            if ($chunk === false || $chunk === '') {
                throw new InvalidArgumentException('Unexpected end of file: '.$path);
            }

            $buffer .= $chunk;
        }

        return $buffer;
    }
}
|