Proyectos de Subversion Moodle

Rev

Autoría | Última modificación | Ver Log |

<?php

declare(strict_types=1);

namespace Phpml\Classification\Ensemble;

use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree;
use Phpml\Exception\InvalidArgumentException;

class RandomForest extends Bagging
{
    /**
     * Strategy used to decide how many features each tree may consider:
     * 'sqrt', 'log', or a float ratio between 0.1 and 1.0.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Optional column names handed to every tree. When null, each tree
     * receives the column indices 0..featureCount-1 instead.
     *
     * @var array|null
     */
    protected $columnNames;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may increase the prediction performance while it will also substantially
     * increase the processing time and the required memory.
     *
     * Also sets the parent Bagging subset ratio to 1.0 (its exact meaning is
     * defined by Bagging::setSubsetRatio).
     */
    public function __construct(int $numClassifier = 50)
    {
        parent::__construct($numClassifier);

        $this->setSubsetRatio(1.0);
    }

    /**
     * This method is used to determine how many of the original columns (features)
     * will be used to construct subsets to train base classifiers.<br>
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
     * Integers are rejected, so pass 1.0 rather than 1.
     *
     * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
     * features to be taken into consideration while selecting subspace of features
     *
     * @param mixed $ratio
     *
     * @throws InvalidArgumentException when the ratio has the wrong type or value
     */
    public function setFeatureSubsetRatio($ratio): self
    {
        if (!is_string($ratio) && !is_float($ratio)) {
            throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
        }

        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree.
     *
     * The method name (including its spelling) must match the parent's
     * setClassifer so that this override takes effect.
     *
     * @throws InvalidArgumentException when $classifier is not DecisionTree::class
     *
     * @return $this
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
        }

        parent::setClassifer($classifier, $classifierOptions);

        return $this;
    }

    /**
     * Returns an importance value for each column seen by the trees.
     *
     * Each column's importance is summed across all trees in the forest. The
     * sums are then normalized so they add up to 1, which is equivalent to
     * normalizing the per-tree averages. The result is sorted from most to least
     * important, with column keys preserved.
     */
    public function getFeatureImportances(): array
    {
        // Accumulate the importance of each column over every tree.
        $sum = [];
        foreach ($this->classifiers as $tree) {
            /** @var DecisionTree $tree */
            foreach ($tree->getFeatureImportances() as $column => $importance) {
                $sum[$column] = ($sum[$column] ?? 0) + $importance;
            }
        }

        // Normalize. A zero total (e.g. every tree reported zero importance)
        // would otherwise raise DivisionByZeroError, so the raw values are
        // returned unchanged in that case.
        $total = array_sum($sum);
        if ((float) $total !== 0.0) {
            foreach ($sum as $column => $importance) {
                $sum[$column] = $importance / $total;
            }
        }

        arsort($sum);

        return $sum;
    }

    /**
     * A string array to represent the columns is given. They are useful
     * when trying to print some information about the trees such as feature importances
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * Configures one base tree with the column names and the number of
     * features it may consider.
     *
     * @throws InvalidArgumentException when $classifier is not a DecisionTree
     *
     * @return DecisionTree
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        if (!$classifier instanceof DecisionTree) {
            throw new InvalidArgumentException(
                sprintf('Classifier %s expected, got %s', DecisionTree::class, get_class($classifier))
            );
        }

        // The default index names are computed per call and not cached on the
        // object, so a later training run with a different feature count gets
        // matching defaults instead of stale ones.
        $columnNames = $this->columnNames ?? range(0, $this->featureCount - 1);

        return $classifier
            ->setColumnNames($columnNames)
            ->setNumFeatures($this->numFeaturesPerTree());
    }

    /**
     * Number of features each tree may consider, derived from featureSubsetRatio
     * and clamped to at most featureCount.
     */
    private function numFeaturesPerTree(): int
    {
        if (is_float($this->featureSubsetRatio)) {
            $count = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            $count = (int) ($this->featureCount ** .5) + 1;
        } else {
            $count = (int) log($this->featureCount, 2) + 1;
        }

        // Truncation can give 0 for small feature counts (e.g. 0.1 * 5 = 0.5).
        // DecisionTree presumably treats 0 as "use every feature" (verify),
        // which would silently defeat the subset, so keep at least one feature.
        return min(max(1, $count), $this->featureCount);
    }
}