python - What should I do to fix my scikit-learn program?
A snippet of code involving RandomForestClassifier, using the Python machine learning library scikit-learn.
I am trying to give different weights to different classes using the class_weight option in scikit-learn's RandomForestClassifier. Below is the code snippet that produces the error:
    print 'training...'
    forest = RandomForestClassifier(n_estimators=500,class_weight= {0:1,1:1,2:1,3:1,4:1,5:1,6:1,7:4})
    forest = forest.fit( train_data[0::,1::], train_data[0::,0] )

    print 'predicting...'
    output = forest.predict(test_data).astype(int)

    predictions_file = open("myfirstforest.csv", "wb")
    open_file_object = csv.writer(predictions_file)
    open_file_object.writerow(["passengerid","survived"])
    open_file_object.writerows(zip(ids, output))
    predictions_file.close()
    print 'done.'
I am getting the following error:
    training...
    IndexError                                Traceback (most recent call last)
    <ipython-input-20-122f2e5a0d3b> in <module>()
         84 print 'training...'
         85 forest = RandomForestClassifier(n_estimators=500,class_weight={0:1,1:1,2:1,3:1,4:1,5:1,6:1,7:4})
    ---> 86 forest = forest.fit( train_data[0::,1::], train_data[0::,0] )
         87
         88 print 'predicting...'

    /home/rpota/anaconda/lib/python2.7/site-packages/sklearn/ensemble/forest.pyc in fit(self, X, y, sample_weight)
        216             self.n_outputs_ = y.shape[1]
        217
    --> 218         y, expanded_class_weight = self._validate_y_class_weight(y)
        219
        220         if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:

    /home/rpota/anaconda/lib/python2.7/site-packages/sklearn/ensemble/forest.pyc in _validate_y_class_weight(self, y)
        433                 class_weight = self.class_weight
        434             expanded_class_weight = compute_sample_weight(class_weight,
    --> 435                                                           y_original)
        436
        437         return y, expanded_class_weight

    /home/rpota/anaconda/lib/python2.7/site-packages/sklearn/utils/class_weight.pyc in compute_sample_weight(class_weight, y, indices)
        150             weight_k = compute_class_weight(class_weight_k,
        151                                             classes_full,
    --> 152                                             y_full)
        153
        154         weight_k = weight_k[np.searchsorted(classes_full, y_full)]

    /home/rpota/anaconda/lib/python2.7/site-packages/sklearn/utils/class_weight.pyc in compute_class_weight(class_weight, classes, y)
         58         for c in class_weight:
         59             i = np.searchsorted(classes, c)
    ---> 60             if classes[i] != c:
         61                 raise ValueError("Class label %d not present." % c)
         62             else:

    IndexError: index 2 is out of bounds for axis 0 with size 2
Please help!
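One possible reading of the traceback, offered as a guess rather than a confirmed diagnosis: the final message reports an array of size 2, which suggests the target column (train_data[0::,0]) contains only two distinct labels, while the class_weight dict names eight (0 to 7). The lookup for label 2 then runs past the end of the classes array. A minimal sketch of how one might check this and keep only the weights for labels that actually occur is shown below. The helper restrict_class_weight and the sample y are hypothetical and not part of scikit-learn; for a NumPy target column, np.unique(y).tolist() could stand in for the plain list.

```python
def restrict_class_weight(class_weight, y):
    """Split a class_weight dict by whether each label occurs in y.

    Hypothetical helper for illustration only; not a scikit-learn function.
    Returns (kept, missing): the weights for labels found in y, and the
    sorted list of dict keys that never appear in y.
    """
    present = sorted(set(y))
    kept = {label: class_weight[label]
            for label in present if label in class_weight}
    missing = sorted(set(class_weight) - set(present))
    return kept, missing


# Assumed binary target, standing in for train_data[0::,0].
y = [0, 1, 1, 0, 1]
class_weight = {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 4}

kept, missing = restrict_class_weight(class_weight, y)
print(kept)     # {0: 1, 1: 1}
print(missing)  # [2, 3, 4, 5, 6, 7]
```

If missing comes back non-empty for the real training data, passing kept as class_weight (or checking whether the wrong column is being used as the target) would be the first thing to try.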