![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
![](https://images.cnblogs.com/OutliningIndicators/ExpandedBlockStart.gif)
# coding=utf-8
"""Train a binary DNN classifier on rows read from ``train_data.csv``.

Each row holds six feature columns followed by an outcome column.
Columns 0, 2 and 4 contain single letters that are mapped to integers
(``'a'`` -> 0, ``'b'`` -> 1, ...); the remaining feature columns are used
as-is. NOTE(review): these look like chess board files/ranks -- confirm
against the dataset. The outcome is reduced to a binary label:
``"draw"`` -> 0, anything else -> 1.

NOTE: ``tf.contrib`` only exists in TensorFlow 1.x; this script requires it.
"""
import functools
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
# ``sklearn.cross_validation`` was removed in scikit-learn 0.20;
# ``model_selection`` provides the same ``train_test_split``.
from sklearn.model_selection import train_test_split

# Backward compatibility: GetInputs() called with no argument returns the
# training split while flag == 0 and the test split otherwise.
flag = 0

# Indices of the feature columns that hold letters rather than numbers.
LETTER_COLUMNS = (0, 2, 4)
NUM_FEATURES = 6
MODEL_DIR = "/home/jiangjing/TensorflowModel/chess"


def DataPreprocessing(csv_path="train_data.csv"):
    """Load the CSV, encode it numerically and split it 75/25.

    Args:
        csv_path: path of the comma-separated input file (first row is
            the header).

    Returns:
        ``(X_train, X_test, Y_train, Y_test)`` as returned by
        ``train_test_split``: float32 feature arrays with ``NUM_FEATURES``
        columns and lists of 0/1 labels. The split is stratified on the
        label and seeded (``random_state=0``), so it is reproducible.
    """
    frame = pd.read_csv(csv_path, sep=",", header=0, keep_default_na=True)

    # Object-dtype array: letter and number columns are mixed until encoded.
    features = np.array(frame.iloc[:, :NUM_FEATURES])
    for col in LETTER_COLUMNS:
        features[:, col] = [ord(ch) - ord("a") for ch in features[:, col]]
    # tf.constant cannot build a tensor from an object array, and
    # real_valued_column defaults to float32, so cast explicitly.
    features = features.astype(np.float32)

    # Binary classifier: "draw" versus every other outcome.
    outcomes = np.array(frame.iloc[:, NUM_FEATURES:])
    labels = [0 if outcome == "draw" else 1 for outcome in outcomes[:, 0]]

    return train_test_split(features, labels, test_size=0.25,
                            random_state=0, stratify=labels)


@functools.lru_cache(maxsize=1)
def _cached_split():
    """Return DataPreprocessing()'s split, reading the CSV only once."""
    return DataPreprocessing()


def GetInputs(training=None):
    """Return ``(features, labels)`` constant tensors for one data split.

    Intended to be passed (wrapped in a zero-argument callable) as an
    estimator ``input_fn``.

    Args:
        training: True for the training split, False for the test split.
            When omitted, falls back to the module-level ``flag``
            (0 -> training, anything else -> test), as before.
    """
    if training is None:
        training = flag == 0
    X_train, X_test, Y_train, Y_test = _cached_split()
    if training:
        return tf.constant(X_train), tf.constant(Y_train)
    return tf.constant(X_test), tf.constant(Y_test)


def Main():
    """Train the classifier on the training split and print test accuracy."""
    feature_columns = [
        tf.contrib.layers.real_valued_column("", dimension=NUM_FEATURES)
    ]
    clf = tf.contrib.learn.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[20, 40, 20],
        n_classes=2,
        model_dir=MODEL_DIR,
    )

    # Select the split explicitly instead of toggling a global, so Main()
    # behaves the same no matter how many times it is called.
    clf.fit(input_fn=lambda: GetInputs(training=True), steps=2000)
    # Constant tensors yield the whole test set per step, so one step suffices.
    accuracy_score = clf.evaluate(
        input_fn=lambda: GetInputs(training=False), steps=1
    )["accuracy"]

    print("\nTest Accuracy:{0:f}".format(accuracy_score))


if __name__ == "__main__":
    Main()
    sys.exit(0)