Proyectos de Subversion Moodle

Rev

| Última modificación | Ver Log |

Rev Autor Línea Nro. Línea
1 efrain 1
<?php
2
 
3
declare(strict_types=1);
4
 
5
namespace Phpml\Classification;
6
 
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Distance;
10
use Phpml\Math\Distance\Euclidean;
11
 
12
/**
 * k-nearest-neighbours classifier.
 *
 * Stores training samples and their targets (presumably populated by the Trainable
 * trait's train()) and predicts a label for a query sample by majority vote among the
 * k stored samples closest to it under the configured distance metric.
 */
class KNearestNeighbors implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * @var int Number of nearest neighbours that vote on each prediction (always >= 1).
     */
    private $k;

    /**
     * @var Distance Metric used to compare a query sample with each stored sample.
     */
    private $distanceMetric;

    /**
     * @param int           $k              number of neighbours consulted per prediction; must be >= 1
     * @param Distance|null $distanceMetric (if null then Euclidean distance as default)
     *
     * @throws \Phpml\Exception\InvalidArgumentException when $k is smaller than 1
     */
    public function __construct(int $k = 3, ?Distance $distanceMetric = null)
    {
        // With k = 0 no neighbour would vote, and a negative k would make array_slice()
        // select "all but the last |k|" samples -- both produce meaningless predictions.
        if ($k < 1) {
            throw new \Phpml\Exception\InvalidArgumentException(
                sprintf('k must be at least 1, %d given', $k)
            );
        }

        $this->k = $k;
        $this->samples = [];
        $this->targets = [];
        $this->distanceMetric = $distanceMetric ?? new Euclidean();
    }

    /**
     * Predicts the label of one sample by majority vote of its k nearest neighbours.
     *
     * Ties between labels with the same number of votes go to the label whose closest
     * voting neighbour is nearer to $sample, so the result depends only on distances,
     * not on the order of the training data or on sort stability.
     *
     * @return mixed the winning label as an array key (numeric-string labels therefore
     *               come back as int); null when there are no training samples
     *
     * @throws \Phpml\Exception\InvalidArgumentException propagated from the distance metric
     */
    protected function predictSample(array $sample)
    {
        $votes = [];
        // Keys of kNeighborsDistances() are training indices ordered nearest-first, so
        // each label enters $votes in order of its closest neighbour.
        foreach (array_keys($this->kNeighborsDistances($sample)) as $index) {
            $label = $this->targets[$index];
            $votes[$label] = ($votes[$label] ?? 0) + 1;
        }

        if ($votes === []) {
            return null;
        }

        // array_search() returns the first key holding the top count, i.e. on a tie the
        // label that was inserted first -- the one with the nearer neighbour.
        return array_search(max($votes), $votes, true);
    }

    /**
     * Computes the distances from $sample to its k closest training samples.
     *
     * @return array distances keyed by the sample's index in $this->samples, ordered from
     *               closest to farthest, with at most $this->k entries
     *
     * @throws \Phpml\Exception\InvalidArgumentException
     */
    private function kNeighborsDistances(array $sample): array
    {
        $distances = [];

        foreach ($this->samples as $index => $neighbor) {
            $distances[$index] = $this->distanceMetric->distance($sample, $neighbor);
        }

        // Sort ascending while keeping the training indices as keys.
        asort($distances);

        return array_slice($distances, 0, $this->k, true);
    }
}