Proyectos de Subversion Moodle

Rev

| Última modificación | Ver Log |

Rev Autor Línea Nro. Línea
1 efrain 1
<?php
2
 
3
declare(strict_types=1);
4
 
5
namespace Phpml\CrossValidation;
6
 
7
use Phpml\Dataset\ArrayDataset;
8
use Phpml\Dataset\Dataset;
9
 
10
/**
 * Random train/test split that is performed separately for every target value,
 * so each target class contributes to the test set in the proportion $testSize.
 */
class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Partitions $dataset by target and delegates the random split of each
     * partition to RandomSplit::splitDataset().
     */
    protected function splitDataset(Dataset $dataset, float $testSize): void
    {
        foreach ($this->splitByTarget($dataset) as $targetSet) {
            parent::splitDataset($targetSet, $testSize);
        }
    }

    /**
     * Groups the samples of $dataset by their target value.
     *
     * Targets are grouped by their string representation. This is the same
     * equivalence array_unique() applies by default, so the grouping semantics are
     * unchanged. Grouping by the raw target as an array key would truncate float
     * targets to integers and merge distinct classes (e.g. 1.5 and 1.7). Every
     * sample keeps its own original target value.
     *
     * Assumes samples and targets share the same keys (as the original code did).
     *
     * @return Dataset[] one dataset per distinct target, in order of first occurrence
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $targets = $dataset->getTargets();
        $groupedSamples = [];
        $groupedTargets = [];

        foreach ($dataset->getSamples() as $key => $sample) {
            $target = $targets[$key];
            // String key: never truncated like a float key, and matches array_unique()'s comparison.
            $group = (string) $target;
            $groupedSamples[$group][] = $sample;
            $groupedTargets[$group][] = $target;
        }

        return $this->createDatasets($groupedSamples, $groupedTargets);
    }

    /**
     * Builds one ArrayDataset per group.
     *
     * @param array<array-key, array> $groupedSamples samples keyed by group
     * @param array<array-key, array> $groupedTargets original targets keyed by the same groups
     *
     * @return Dataset[]
     */
    private function createDatasets(array $groupedSamples, array $groupedTargets): array
    {
        $datasets = [];
        foreach ($groupedSamples as $group => $samples) {
            $datasets[] = new ArrayDataset($samples, $groupedTargets[$group]);
        }

        return $datasets;
    }
}