from common.utils import plot

    xiaoxiao  2024-12-13  51

    使用以下函数（来源：scikit-learn 官方文档中的学习曲线示例）：

    def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                            n_jobs=1, train_sizes=None):
        """Plot training and cross-validation learning curves for an estimator.

        Adapted from the scikit-learn documentation example. Scores are
        computed with ``sklearn.model_selection.learning_curve``. They are
        drawn on a new pyplot figure as mean-score lines, with a shaded band
        of +/- one standard deviation across the CV splits.

        Parameters
        ----------
        estimator : object
            Estimator passed unchanged to ``learning_curve``.
        title : str
            Figure title.
        X, y : array-like
            Training data and targets, passed unchanged to ``learning_curve``.
        ylim : tuple (ymin, ymax), optional
            If given, fixes the y-axis limits.
        cv : int, cross-validation generator or iterable, optional
            Cross-validation strategy, passed unchanged to ``learning_curve``.
        n_jobs : int, default 1
            Number of parallel jobs, passed unchanged to ``learning_curve``.
        train_sizes : array-like, optional
            Training-set sizes to evaluate. ``None`` (the default) means
            ``np.linspace(0.1, 1.0, 5)``, the same five points the original
            default used.

        Returns
        -------
        module
            The ``matplotlib.pyplot`` module, so callers can chain
            ``plot_learning_curve(...).show()``.
        """
        # Imported here so the function is self-contained. Importing the
        # enclosing module does not require matplotlib/scikit-learn/numpy.
        import matplotlib.pyplot as plt
        import numpy as np
        from sklearn.model_selection import learning_curve

        # A None sentinel avoids building a numpy array at definition time.
        if train_sizes is None:
            train_sizes = np.linspace(0.1, 1.0, 5)

        plt.figure()
        plt.title(title)
        if ylim is not None:
            plt.ylim(*ylim)
        plt.xlabel("Training examples")
        plt.ylabel("Score")

        # learning_curve returns the absolute training sizes actually used,
        # plus score arrays. Per the scikit-learn docs, each score array has
        # shape (n_sizes, n_cv_splits), so axis=1 aggregates over CV splits.
        sizes_used, train_scores, test_scores = learning_curve(
            estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
        train_scores_mean = np.mean(train_scores, axis=1)
        train_scores_std = np.std(train_scores, axis=1)
        test_scores_mean = np.mean(test_scores, axis=1)
        test_scores_std = np.std(test_scores, axis=1)

        plt.grid()

        # Shaded +/- 1 std bands: red for training, green for validation.
        plt.fill_between(sizes_used, train_scores_mean - train_scores_std,
                         train_scores_mean + train_scores_std, alpha=0.1,
                         color="r")
        plt.fill_between(sizes_used, test_scores_mean - test_scores_std,
                         test_scores_mean + test_scores_std, alpha=0.1,
                         color="g")
        plt.plot(sizes_used, train_scores_mean, 'o-', color="r",
                 label="Training score")
        plt.plot(sizes_used, test_scores_mean, 'o-', color="g",
                 label="Cross-validation score")

        plt.legend(loc="best")
        return plt

     

    # Usage example.
    # Guarded so that importing this file does not run the (slow,
    # plot-producing) demo.
    if __name__ == "__main__":
        from sklearn.datasets import load_digits
        from sklearn.model_selection import ShuffleSplit
        from sklearn.neighbors import KNeighborsClassifier

        # The original snippet used X and y without defining them. Load a
        # small built-in dataset so the example runs as-is; replace it with
        # your own data as needed.
        X, y = load_digits(return_X_y=True)

        knn = KNeighborsClassifier()
        # 10 random splits, each holding out 20% of the samples for testing;
        # the fixed seed makes the curve reproducible.
        cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)

        # plot_learning_curve returns `plt`; calling show() actually
        # displays the figure when run as a script.
        plot_learning_curve(knn, "Learning curve (KNeighborsClassifier)", X, y,
                            ylim=(0.0, 1.01), cv=cv).show()

     

    最新回复(0)