Proyectos de Subversion Moodle

Rev

| Última modificación | Ver Log |

Rev Autor Línea Nro. Línea
1 efrain 1
<?php
2
 
3
declare(strict_types=1);
4
 
5
namespace Phpml\Dataset;
6
 
7
use Phpml\Exception\InvalidArgumentException;
8
 
9
/**
 * MNIST dataset: http://yann.lecun.com/exdb/mnist/
 * original mnist dataset reader: https://github.com/AndrewCarterUK/mnist-neural-network-plain-php
 *
 * Loads an IDX image file and an IDX label file into the samples/targets
 * arrays inherited from ArrayDataset.
 *
 * - Each sample is a flat list of IMAGE_ROWS * IMAGE_COLS pixel bytes, each divided by 255.
 * - Each target is the corresponding unsigned label byte.
 */
final class MnistDataset extends ArrayDataset
{
    private const MAGIC_IMAGE = 0x00000803;

    private const MAGIC_LABEL = 0x00000801;

    private const IMAGE_ROWS = 28;

    private const IMAGE_COLS = 28;

    /**
     * @throws InvalidArgumentException if a file cannot be opened, has an unexpected header,
     *                                  is truncated, or if the image and label counts differ
     */
    public function __construct(string $imagePath, string $labelPath)
    {
        $this->samples = $this->readImages($imagePath);
        $this->targets = $this->readLabels($labelPath);

        if (count($this->samples) !== count($this->targets)) {
            throw new InvalidArgumentException('Must have the same number of images and labels');
        }
    }

    /**
     * Reads every image declared in the IDX image header.
     *
     * @return array<int, array<int, int|float>> one flat pixel list per image, values scaled by 1/255
     *
     * @throws InvalidArgumentException on open failure, bad header or truncated data
     */
    private function readImages(string $imagePath): array
    {
        $stream = $this->openStream($imagePath);

        $images = [];

        try {
            // Header: magic, image count, rows, cols — four big-endian unsigned 32-bit ints.
            $fields = unpack('Nmagic/Nsize/Nrows/Ncols', $this->readBytes($stream, 16, $imagePath));

            if ($fields['magic'] !== self::MAGIC_IMAGE) {
                throw new InvalidArgumentException('Invalid magic number: '.$imagePath);
            }

            if ($fields['rows'] !== self::IMAGE_ROWS) {
                throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath);
            }

            if ($fields['cols'] !== self::IMAGE_COLS) {
                throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath);
            }

            $imageSize = $fields['rows'] * $fields['cols'];

            for ($i = 0; $i < $fields['size']; $i++) {
                // readBytes() guarantees a full image, so a truncated file cannot produce short samples.
                $imageBytes = $this->readBytes($stream, $imageSize, $imagePath);

                // Convert to a value between 0 and 1
                $images[] = array_map(static function ($b) {
                    return $b / 255;
                }, array_values(unpack('C*', $imageBytes)));
            }
        } finally {
            fclose($stream);
        }

        return $images;
    }

    /**
     * Reads every label declared in the IDX label header.
     *
     * @return array<int, int> one unsigned byte per label
     *
     * @throws InvalidArgumentException on open failure, bad header or truncated data
     */
    private function readLabels(string $labelPath): array
    {
        $stream = $this->openStream($labelPath);

        try {
            // Header: magic, label count — two big-endian unsigned 32-bit ints.
            $fields = unpack('Nmagic/Nsize', $this->readBytes($stream, 8, $labelPath));

            if ($fields['magic'] !== self::MAGIC_LABEL) {
                throw new InvalidArgumentException('Invalid magic number: '.$labelPath);
            }

            $labelBytes = $this->readBytes($stream, $fields['size'], $labelPath);
        } finally {
            fclose($stream);
        }

        // unpack('C*', '') yields an empty array, so a declared size of 0 gives no labels.
        return array_values(unpack('C*', $labelBytes));
    }

    /**
     * Opens $path for binary reading.
     *
     * @return resource
     *
     * @throws InvalidArgumentException if the file cannot be opened
     */
    private function openStream(string $path)
    {
        $stream = fopen($path, 'rb');

        if ($stream === false) {
            throw new InvalidArgumentException('Could not open file: '.$path);
        }

        return $stream;
    }

    /**
     * Reads exactly $length bytes from $stream.
     *
     * This loops because fread() may return fewer bytes than requested on
     * non-plain stream wrappers. It also never calls fread() with a length of 0.
     *
     * @param resource $stream
     *
     * @throws InvalidArgumentException if the stream ends or errors before $length bytes are read
     */
    private function readBytes($stream, int $length, string $path): string
    {
        $data = '';

        while (strlen($data) < $length) {
            $chunk = fread($stream, $length - strlen($data));

            if ($chunk === false || $chunk === '') {
                throw new InvalidArgumentException('Unexpected end of file: '.$path);
            }

            $data .= $chunk;
        }

        return $data;
    }
}