Autoría | Última modificación | Ver Log |
<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;

/**
 * Binary decision tree classifier.
 *
 * Every internal node tests a single column. Nominal columns are split on
 * "equals the most frequent value" versus "anything else". Continuous columns
 * are first discretised around the median of the values reaching the node
 * (see preprocess()). At each node the column with the lowest Gini index is
 * chosen.
 */
class DecisionTree implements Classifier
{
    use Trainable;
    use Predictable;

    public const CONTINUOUS = 1;

    public const NOMINAL = 2;

    /**
     * Deepest level reached while building the current tree.
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Per-column type (self::CONTINUOUS or self::NOMINAL), indexed by column.
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Root node of the trained tree.
     *
     * @var DecisionTreeLeaf
     */
    protected $tree;

    /**
     * @var int
     */
    protected $maxDepth;

    /**
     * Distinct target labels, in order of first appearance.
     *
     * @var array
     */
    private $labels = [];

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * @var array
     */
    private $selectedFeatures = [];

    /**
     * Cached result of getFeatureImportances(); null when not yet computed.
     *
     * @var array|null
     */
    private $featureImportances;

    /**
     * @var array
     */
    private $columnNames = [];

    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Appends the given samples/targets to the ones already seen and rebuilds
     * the whole tree from the accumulated data.
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // The tree is rebuilt from scratch, so the depth recorded for a
        // previously built tree must not leak into the new one
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // we will have to compute it again depending on the new data
        $this->featureImportances = null;

        // If column names are given or computed before, then there is no
        // need to init it and accidentally remove the previous given names
        if ($this->columnNames === []) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif (count($this->columnNames) > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif (count($this->columnNames) < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range(count($this->columnNames), $this->featureCount - 1)
            );
        }
    }

    /**
     * Classifies every column of the given samples as NOMINAL or CONTINUOUS.
     *
     * The number of columns is taken from the first sample.
     *
     * @return array column type per column index
     */
    public static function getColumnTypes(array $samples): array
    {
        $types = [];
        $featureCount = count($samples[0]);
        for ($i = 0; $i < $featureCount; ++$i) {
            $values = array_column($samples, $i);
            $isCategorical = self::isCategoricalColumn($values);
            $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
        }

        return $types;
    }

    /**
     * Computes the weighted Gini index of splitting the given column values
     * into two groups: values equal to $baseValue and all other values.
     *
     * Each group's impurity (1 minus the sum of squared class proportions) is
     * weighted by the group's size. The sum is then divided by the total
     * number of values.
     *
     * @param mixed $baseValue value defining the "equal" side of the split
     * @param array $colValues column values keyed by record index
     * @param array $targets   target labels keyed by the same record indices
     *
     * @return float weighted impurity; 0.0 when $colValues is empty
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        if ($colValues === []) {
            // An empty column cannot be impure; also avoids dividing by zero below
            return 0.0;
        }

        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }

        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = self::isSameSplitValue($value, $baseValue) ? 0 : 1;
            ++$countMatrix[$label][$rowIndex];
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += ($countMatrix[$label][$i] / (float) $sum) ** 2;
                }
            }

            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / count($colValues);
    }

    /**
     * This method is used to set number of columns to be used
     * when deciding a split at an internal node of the tree. <br>
     * If the value is given 0, then all features are used (default behaviour),
     * otherwise the given value will be used as a maximum for number of columns
     * randomly selected for each split operation.
     *
     * @return $this
     *
     * @throws InvalidArgumentException
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures < 0) {
            throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
        }

        $this->numUsableFeatures = $numFeatures;

        return $this;
    }

    /**
     * A string array to represent columns. Useful when HTML output or
     * column importances are desired to be inspected.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when the tree is trained and the
     *                                  length differs from the feature count
     */
    public function setColumnNames(array $names)
    {
        if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
        }

        $this->columnNames = $names;

        return $this;
    }

    /**
     * Renders the trained tree as HTML using the configured column names.
     */
    public function getHtml(): string
    {
        return $this->tree->getHTML($this->columnNames);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. The importance values are
     * normalized and their total makes 1.<br/>
     *
     * The result is cached until the next call to train().
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $this->featureImportances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $nodes = $this->getSplitNodesByColumn($column, $this->tree);

            $importance = 0;
            foreach ($nodes as $node) {
                $importance += $node->getNodeImpurityDecrease($sampleCount);
            }

            $this->featureImportances[$columnName] = $importance;
        }

        // Normalize & sort the importances
        $total = array_sum($this->featureImportances);
        if ($total > 0) {
            array_walk($this->featureImportances, function (&$importance) use ($total): void {
                $importance /= $total;
            });
            arsort($this->featureImportances);
        }

        return $this->featureImportances;
    }

    /**
     * Recursively builds the (sub)tree for the given record indices.
     *
     * A node becomes terminal when all its records are identical, when the
     * maximum depth is reached, or when only one target class remains. A
     * terminal node's class value is its most frequent target, cast to string.
     *
     * @param array $records indices into $this->samples / $this->targets
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records belong to the same class,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            // Check if the previous record is the same with the current one
            $record = $this->samples[$recordNo];
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }

            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Group remaining targets
            $target = $this->targets[$recordNo];
            if (!array_key_exists($target, $remainingTargets)) {
                $remainingTargets[$target] = 1;
            } else {
                ++$remainingTargets[$target];
            }
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            $split->isTerminal = true;
            arsort($remainingTargets);
            $split->classValue = (string) key($remainingTargets);
        } else {
            if (isset($leftRecords[0])) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }

            if (isset($rightRecords[0])) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Chooses the split for the given records.
     *
     * For each selected feature, the most frequent (preprocessed) value is
     * tried as the split value. The feature with the lowest Gini index wins;
     * on ties the earlier feature is kept.
     *
     * @param array $records indices into $this->samples / $this->targets
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $targets = array_intersect_key($this->targets, array_flip($records));
        $samples = (array) array_combine(
            $records,
            $this->preprocess(array_intersect_key($this->samples, array_flip($records)))
        );
        $bestGiniVal = 1;
        $bestSplit = null;
        $features = $this->getSelectedFeatures();
        foreach ($features as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            if ($baseValue === null) {
                continue;
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;
                $split->records = $records;

                // If a numeric column is to be selected, then
                // the original numeric value and the selected operator
                // will also be saved into the leaf for future access
                if ($this->columnTypes[$i] === self::CONTINUOUS) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = (float) $matches[2];
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Returns available features/columns to the tree for the decision making
     * process. <br>
     *
     * If a number is given with setNumFeatures() method, then a random selection
     * of features up to this number is returned. <br>
     *
     * If some features are manually selected by use of setSelectedFeatures(),
     * then only these features are returned <br>
     *
     * If any of above methods were not called beforehand, then all features
     * are returned by default.
     */
    protected function getSelectedFeatures(): array
    {
        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0 && count($this->selectedFeatures) === 0) {
            return $allFeatures;
        }

        if (count($this->selectedFeatures) > 0) {
            return $this->selectedFeatures;
        }

        $numFeatures = $this->numUsableFeatures;
        if ($numFeatures > $this->featureCount) {
            $numFeatures = $this->featureCount;
        }

        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures);
        sort($selectedFeatures);

        return $selectedFeatures;
    }

    /**
     * Returns a copy of the given samples (re-indexed as a list of rows) in
     * which every continuous column value is replaced by "<= {median}" or
     * "> {median}". The median is taken over the given samples only.
     */
    protected function preprocess(array $samples): array
    {
        // Detect and convert continuous data column values into
        // discrete values by using the median as a threshold value
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    if ($value <= $median) {
                        $value = "<= {$median}";
                    } else {
                        $value = "> {$median}";
                    }
                }

                // Drop the reference so the last element of $values (copied
                // into $columns below) is not aliased to $value any more
                unset($value);
            }

            $columns[] = $values;
        }

        if (count($columns) === 1) {
            // array_map(null, $single) returns $single unchanged instead of
            // wrapping each value into a one-element row, which would make
            // $row[$i] in getBestSplit() read a string offset / scalar
            return array_map(static function ($value): array {
                return [$value];
            }, $columns[0]);
        }

        // Below method is a strange yet very simple & efficient method
        // to get the transpose of a 2D array
        return array_map(null, ...$columns);
    }

    /**
     * Heuristically decides whether a column holds a discrete set of values.
     *
     * - Any float value makes the column continuous.
     * - Otherwise any non-numeric value (e.g. a non-numeric string) makes it
     *   categorical.
     * - Otherwise (all values numeric, none float) the column is categorical
     *   when its distinct values number at most 20% of all values.
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        $floatValues = array_filter($columnValues, 'is_float');
        if (count($floatValues) > 0) {
            return false;
        }

        if (count($numericValues) !== $count) {
            return true;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * Used to set predefined features to consider while deciding which column to use for a split
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }

    /**
     * Collects and returns an array of internal nodes that use the given
     * column as a split criterion
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        if ($node->isTerminal) {
            return [];
        }

        $nodes = [];
        if ($node->columnIndex === $column) {
            $nodes[] = $node;
        }

        $lNodes = [];
        $rNodes = [];
        if ($node->leftLeaf !== null) {
            $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
        }

        if ($node->rightLeaf !== null) {
            $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
        }

        return array_merge($nodes, $lNodes, $rNodes);
    }

    /**
     * Walks the tree from the root until a terminal node is reached. At each
     * node it goes left when the node's evaluate() accepts the sample, and
     * right otherwise.
     *
     * @return mixed the terminal node's class value, or the first label seen
     *               during training if the walk runs into a missing child
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        do {
            if ($node->isTerminal) {
                return $node->classValue;
            }

            if ($node->evaluate($sample)) {
                $node = $node->leftLeaf;
            } else {
                $node = $node->rightLeaf;
            }
        } while ($node);

        return $this->labels[0];
    }

    /**
     * Tells whether a column value falls on the "equal" side of a split on
     * $baseValue.
     *
     * getBestSplit() takes $baseValue from the keys of array_count_values().
     * PHP casts integer-like string keys such as "1" to int 1, so a plain
     * strict comparison would never match the original string "1" against
     * that key. Int/string pairs are therefore compared by their string form,
     * which mirrors that key cast ("01" still differs from 1). All other
     * values require an exact (===) match.
     * NOTE(review): confirm this agrees with DecisionTreeLeaf::evaluate().
     *
     * @param mixed $value
     * @param mixed $baseValue
     */
    private static function isSameSplitValue($value, $baseValue): bool
    {
        if ($value === $baseValue) {
            return true;
        }

        return (is_int($value) || is_string($value))
            && (is_int($baseValue) || is_string($baseValue))
            && (string) $value === (string) $baseValue;
    }
}