AutorÃa | Ultima modificación | Ver Log |
<?phpdeclare(strict_types=1);namespace Phpml\Classification\Linear;use Phpml\Classification\DecisionTree;use Phpml\Classification\WeightedClassifier;use Phpml\Exception\InvalidArgumentException;use Phpml\Helper\OneVsRest;use Phpml\Helper\Predictable;use Phpml\Math\Comparison;class DecisionStump extends WeightedClassifier{use Predictable;use OneVsRest;public const AUTO_SELECT = -1;/*** @var int*/protected $givenColumnIndex;/*** @var array*/protected $binaryLabels = [];/*** Lowest error rate obtained while training/optimizing the model** @var float*/protected $trainingErrorRate;/*** @var int*/protected $column;/*** @var mixed*/protected $value;/*** @var string*/protected $operator;/*** @var array*/protected $columnTypes = [];/*** @var int*/protected $featureCount;/*** @var float*/protected $numSplitCount = 100.0;/*** Distribution of samples in the leaves** @var array*/protected $prob = [];/*** A DecisionStump classifier is a one-level deep DecisionTree. It is generally* used with ensemble algorithms as in the weak classifier role. <br>** If columnIndex is given, then the stump tries to produce a decision node* on this column, otherwise in cases given the value of -1, the stump itself* decides which column to take for the decision (Default DecisionTree behaviour)*/public function __construct(int $columnIndex = self::AUTO_SELECT){$this->givenColumnIndex = $columnIndex;}public function __toString(): string{return "IF {$this->column} {$this->operator} {$this->value} ".'THEN '.$this->binaryLabels[0].' '.'ELSE '.$this->binaryLabels[1];}/*** While finding best split point for a numerical valued column,* DecisionStump looks for equally distanced values between minimum and maximum* values in the column. Given <i>$count</i> value determines how many split* points to be probed. The more split counts, the better performance but* worse processing time (Default value is 10.0)*/public function setNumericalSplitCount(float $count): void{$this->numSplitCount = $count;}/*** @throws InvalidArgumentException*/protected function trainBinary(array $samples, array $targets, array $labels): void{$this->binaryLabels = $labels;$this->featureCount = count($samples[0]);// If a column index is given, it should be among the existing columnsif ($this->givenColumnIndex > count($samples[0]) - 1) {$this->givenColumnIndex = self::AUTO_SELECT;}// Check the size of the weights given.// If none given, then assign 1 as a weight to each sampleif (count($this->weights) === 0) {$this->weights = array_fill(0, count($samples), 1);} else {$numWeights = count($this->weights);if ($numWeights !== count($samples)) {throw new InvalidArgumentException('Number of sample weights does not match with number of samples');}}// Determine type of each column as either "continuous" or "nominal"$this->columnTypes = DecisionTree::getColumnTypes($samples);// Try to find the best split in the columns of the dataset// by calculating error rate for each split point in each column$columns = range(0, count($samples[0]) - 1);if ($this->givenColumnIndex !== self::AUTO_SELECT) {$columns = [$this->givenColumnIndex];}$bestSplit = ['value' => 0,'operator' => '','prob' => [],'column' => 0,'trainingErrorRate' => 1.0,];foreach ($columns as $col) {if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {$split = $this->getBestNumericalSplit($samples, $targets, $col);} else {$split = $this->getBestNominalSplit($samples, $targets, $col);}if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {$bestSplit = $split;}}// Assign determined best values to the stumpforeach ($bestSplit as $name => $value) {$this->{$name} = $value;}}/*** Determines best split point for the given column*/protected function getBestNumericalSplit(array $samples, array $targets, int $col): array{$values = array_column($samples, $col);// Trying all possible points may be accomplished in two general ways:// 1- Try all values in the $samples array ($values)// 2- Artificially split the range of values into several parts and try them// We choose the second one because it is faster in larger datasets$minValue = min($values);$maxValue = max($values);$stepSize = ($maxValue - $minValue) / $this->numSplitCount;$split = [];foreach (['<=', '>'] as $operator) {// Before trying all possible split points, let's first try// the average value for the cut point$threshold = array_sum($values) / (float) count($values);[$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {$split = ['value' => $threshold,'operator' => $operator,'prob' => $prob,'column' => $col,'trainingErrorRate' => $errorRate,];}// Try other possible points one by onefor ($step = $minValue; $step <= $maxValue; $step += $stepSize) {$threshold = (float) $step;[$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);if ($errorRate < $split['trainingErrorRate']) {$split = ['value' => $threshold,'operator' => $operator,'prob' => $prob,'column' => $col,'trainingErrorRate' => $errorRate,];}}// for}return $split;}protected function getBestNominalSplit(array $samples, array $targets, int $col): array{$values = array_column($samples, $col);$valueCounts = array_count_values($values);$distinctVals = array_keys($valueCounts);$split = [];foreach (['=', '!='] as $operator) {foreach ($distinctVals as $val) {[$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);if (!isset($split['trainingErrorRate']) || $split['trainingErrorRate'] < $errorRate) {$split = ['value' => $val,'operator' => $operator,'prob' => $prob,'column' => $col,'trainingErrorRate' => $errorRate,];}}}return $split;}/*** Calculates the ratio of wrong predictions based on the new threshold* value given as the parameter*/protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values): array{$wrong = 0.0;$prob = [];$leftLabel = $this->binaryLabels[0];$rightLabel = $this->binaryLabels[1];foreach ($values as $index => $value) {if (Comparison::compare($value, $threshold, $operator)) {$predicted = $leftLabel;} else {$predicted = $rightLabel;}$target = $targets[$index];if ((string) $predicted != (string) $targets[$index]) {$wrong += $this->weights[$index];}if (!isset($prob[$predicted][$target])) {$prob[$predicted][$target] = 0;}++$prob[$predicted][$target];}// Calculate probabilities: Proportion of labels in each leaf$dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));foreach ($prob as $leaf => $counts) {$leafTotal = (float) array_sum($prob[$leaf]);foreach ($counts as $label => $count) {if ((string) $leaf == (string) $label) {$dist[$leaf] = $count / $leafTotal;}}}return [$wrong / (float) array_sum($this->weights), $dist];}/*** Returns the probability of the sample of belonging to the given label** Probability of a sample is calculated as the proportion of the label* within the labels of the training samples in the decision node** @param mixed $label*/protected function predictProbability(array $sample, $label): float{$predicted = $this->predictSampleBinary($sample);if ((string) $predicted == (string) $label) {return $this->prob[$label];}return 0.0;}/*** @return mixed*/protected function predictSampleBinary(array $sample){if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {return $this->binaryLabels[0];}return $this->binaryLabels[1];}protected function resetBinary(): void{}}