Proyectos de Subversion Moodle

Rev

| Última modificación | Ver Log |

Rev Autor Línea Nro. Línea
1 efrain 1
<?php
2
 
3
declare(strict_types=1);
4
 
5
namespace Phpml\Regression;
6
 
7
use Phpml\Helper\Predictable;
8
use Phpml\Math\Matrix;
9
 
10
/**
 * Ordinary least-squares linear regression.
 *
 * Fits y = intercept + sum_i coefficients[i] * x[i] by solving the normal
 * equations b = (X'X)^-1 X'Y, where X is the sample matrix extended with a
 * leading column of ones so that the first solved coefficient is the intercept.
 *
 * Training is cumulative: every call to train() appends the new data to what
 * has been seen before and refits the model on the full data set.
 */
class LeastSquares implements Regression
{
    use Predictable;

    /**
     * All samples received so far, one feature row per sample.
     *
     * @var array
     */
    private $samples = [];

    /**
     * All targets received so far, aligned by position with $samples.
     *
     * @var array
     */
    private $targets = [];

    /**
     * Fitted intercept; null until the model has been trained.
     *
     * @var float|null
     */
    private $intercept;

    /**
     * Fitted feature weights, indexed 0..n-1 in the same order as sample features.
     *
     * @var array
     */
    private $coefficients = [];

    /**
     * Append the given data to the accumulated training set and refit the model.
     *
     * array_merge() renumbers numeric keys, so the stored samples/targets are
     * always zero-based lists regardless of how the caller keyed them.
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $this->computeCoefficients();
    }

    /**
     * Predict the target value for a single sample.
     *
     * @param array $sample feature values; $sample[$i] is multiplied by coefficient $i
     *
     * @return mixed intercept plus the weighted feature sum; null if the model was never trained
     */
    public function predictSample(array $sample)
    {
        $result = $this->intercept;
        foreach ($this->coefficients as $index => $coefficient) {
            $result += $coefficient * $sample[$index];
        }

        return $result;
    }

    /**
     * Fitted feature weights (the intercept is not included).
     */
    public function getCoefficients(): array
    {
        return $this->coefficients;
    }

    /**
     * Fitted intercept term.
     *
     * NOTE: $intercept is null until train() has been called, so calling this
     * on an untrained model fails the float return type with a TypeError.
     */
    public function getIntercept(): float
    {
        return $this->intercept;
    }

    /**
     * Solve the normal equations: coefficient(b) = (X'X)^-1 X'Y.
     */
    private function computeCoefficients(): void
    {
        $samplesMatrix = $this->getSamplesMatrix();
        $targetsMatrix = $this->getTargetsMatrix();

        // X' is needed for both X'X and X'Y, so transpose only once.
        $transposed = $samplesMatrix->transpose();

        $ts = $transposed->multiply($samplesMatrix)->inverse();
        $tf = $transposed->multiply($targetsMatrix);

        // Only the first column of the solution is kept. When targets are given
        // as rows (multi-output), the remaining columns are discarded.
        $this->coefficients = $ts->multiply($tf)->getColumnValues(0);

        // The first coefficient belongs to the constant column of ones added in
        // getSamplesMatrix(); it is the intercept. array_shift() also reindexes
        // the remaining weights to 0..n-1 to line up with sample features.
        $this->intercept = array_shift($this->coefficients);
    }

    /**
     * Build the design matrix X: each sample row prefixed with a constant 1
     * so the intercept is estimated as an ordinary coefficient.
     */
    private function getSamplesMatrix(): Matrix
    {
        $samples = [];
        foreach ($this->samples as $sample) {
            array_unshift($sample, 1);
            $samples[] = $sample;
        }

        return new Matrix($samples);
    }

    /**
     * Build the target matrix Y.
     *
     * Targets given as rows (arrays) are used as-is; scalar targets are turned
     * into a single-column matrix.
     */
    private function getTargetsMatrix(): Matrix
    {
        if (is_array($this->targets[0])) {
            return new Matrix($this->targets);
        }

        return Matrix::fromFlatArray($this->targets);
    }
}