import os
from collections import Counter

import numpy as np
from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.utils import check_arrays

os.system("clear")
##########################
class myBayes:
    """Naive Bayes classifier for categorical features with add-one smoothing.

    Every feature column is treated as a discrete variable: each distinct
    value is its own category (no binning, no Gaussian assumption).

    Attributes:
        Py: ``{label: P(label)}``, the smoothed class priors.
        Px: ``{feature_index: {(value, label): P(value | label)}}``.
        nx: number of feature columns seen by ``fit``.
        lx: initialised to 0 and not used by this class; kept so existing
            attribute access keeps working.
        result: the value returned by the most recent successful ``predict``.
    """

    def __init__(self):
        self.Px = {}
        self.Py = {}
        self.nx = 0
        self.lx = 0
        self.result = None
        # {feature_index: {label: probability of a value never seen with label}}
        self._unseen = {}

    def fit(self, X, y):
        """Estimate smoothed priors and per-feature conditional probabilities.

        With ``k`` classes and ``s`` distinct values in a feature column:

            P(label)         = (count(label) + 1) / (n_samples + k)
            P(value | label) = (count(value, label) + 1) / (count(label) + s)

        Args:
            X: 2-D array-like, one row per sample and one column per feature.
            y: 1-D array-like of class labels, one per row of ``X``.

        Raises:
            ValueError: if ``X`` is not 2-D or ``y`` does not hold exactly
                one label per row of ``X``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError("X must be 2-D (n_samples, n_features)")
        if y.ndim != 1 or len(y) != X.shape[0]:
            raise ValueError("y must hold exactly one label per row of X")

        labels = y.tolist()
        n_samples = len(labels)
        class_counts = Counter(labels)
        n_classes = len(class_counts)

        priors = {}
        for label, count in class_counts.items():
            priors[label] = (count + 1.0) / (n_samples + n_classes)

        n_col = X.shape[1]
        conditionals = {}
        unseen = {}
        for i in range(n_col):
            column = X[:, i].tolist()
            n_values = len(set(column))
            # One pass over the column instead of list.count() per pair.
            pair_counts = Counter(zip(column, labels))
            conditionals[i] = {}
            for (value, label), count in pair_counts.items():
                conditionals[i][(value, label)] = (
                    (count + 1.0) / (class_counts[label] + n_values)
                )
            # A pair with zero observations gets the same add-one estimate,
            # (0 + 1) / (count(label) + s), so no class score collapses to 0.
            unseen[i] = {
                label: 1.0 / (count + n_values)
                for label, count in class_counts.items()
            }

        self.Py = priors
        self.Px = conditionals
        self._unseen = unseen
        self.nx = n_col

    def predict(self, test_X):
        """Predict class labels for one sample or for a batch of samples.

        Args:
            test_X: a 1-D sequence of ``nx`` feature values (one sample), or
                a 2-D array-like with ``nx`` columns (one sample per row).

        Returns:
            A 0-d numpy array holding the label for 1-D input, a 1-D numpy
            array of labels for 2-D input, or ``None`` when the input is a
            scalar, has more than two dimensions, or its feature count does
            not equal ``nx``. A non-``None`` return value is also stored in
            ``self.result``.

        Raises:
            ValueError: if a sample has to be scored before ``fit`` has
                populated ``Py``.
        """
        tX = np.array(test_X)
        if tX.ndim == 1:
            if len(tX) != self.nx:
                return None
            self.result = np.array(self._classify(tX))
        elif tX.ndim == 2:
            if tX.shape[1] != self.nx:
                return None
            self.result = np.array([self._classify(row) for row in tX])
        else:
            return None
        return self.result

    def _classify(self, sample):
        """Return the label with the highest joint probability for one sample."""
        if not self.Py:
            raise ValueError("predict() called before fit()")

        best_label = None
        best_score = -1.0
        for label, prior in self.Py.items():
            score = prior
            for j, value in enumerate(sample):
                # Fall back to 0.0 only when the tables were filled in by
                # hand and no smoothing information is available.
                fallback = self._unseen.get(j, {}).get(label, 0.0)
                score *= self.Px[j].get((value, label), fallback)
            # ">=": on an exact tie the label visited later in self.Py wins.
            if score >= best_score:
                best_label = label
                best_score = score
        return best_label
#x1=[1,1,1,1,1,2,2,2,2,2,3,3,3,3,3]
#x2=['s','m','m','s','s','s','m','m','l','l','l','m','m','l','l']
#y=[0,0,1,1,0,0,0,1,1,1,1,1,1,1,0]
#x2dict={'s':1,'m':2,'l':3}
#X=np.array(zip(x1,[x2dict[x] for x in x2]))
#y=np.array(y)
#print X,y

# Demo: fit the classifier on the iris data and report hold-out accuracy.
iris = datasets.load_iris()
X = iris.data
y = iris.target

# test_size=0.2 holds out 20% of the samples. No random_state is passed, so
# the split (and therefore the printed accuracy) can differ between runs.
trainX, testX, trainy, testy = train_test_split(X, y, test_size=0.2)

clf = myBayes()
clf.fit(trainX, trainy)
predicted = clf.predict(testX)

# Single-argument print(...) behaves the same on Python 2 and Python 3.
print(testy)
print(predicted)
print(np.mean(testy == predicted))
###########################################
输出结果:
yuanzhen@yuanzhen-ThinkPad-X121e:~/P_script$ python mybayes.py
[2 0 2 0 2 0 2 1 2 0 0 0 1 2 2 1 2 0 2 1 2 2 2 1 1 2 1 1 0 2]
[2 0 2 0 2 0 2 1 2 0 0 2 2 2 2 1 2 0 2 1 2 1 2 2 1 2 1 1 0 2]
0.866666666667
[0 1 1 0 1 2 1 1 0 0 0 1 1 0 2 1 2 0 1 2 0 2 0 2 2 2 2 2 0 0]
[0 1 2 0 1 2 1 1 0 0 0 1 2 0 1 1 1 0 1 2 2 2 0 2 2 2 2 1 2 0]
0.766666666667
结果显示预测并不稳定