博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Chess
阅读量:7260 次
发布时间:2019-06-29

本文共 2504 字,大约阅读时间需要 8 分钟。

1 # coding=utf-8 2 import pandas as pd 3 import numpy as np 4 from sklearn import cross_validation 5 import tensorflow as tf 6  7 global flag 8 flag=0 9 10 def DataPreprocessing():11     abalone = pd.read_csv("train_data.csv", sep=',', header=0, keep_default_na=True)12     X_train=np.array(abalone.iloc[:,:6])13     Y_train_=np.array(abalone.iloc[:,6:])14     print(X_train)15     Y_train=[]16     for i in range(len(X_train)):17 18         X_train[i][0] = ord(X_train[i][0])-9719         X_train[i][2] = ord(X_train[i][2])-9720         X_train[i][4] = ord(X_train[i][4])-9721 22     # for i in range (len(X_train)):23     #     for j in range(6):24     #         X_train[i][j]=X_train[i][j]-0.025     #26     #X_train.astype(np.float64)27   #  print(X_train,type(X_train),X_train[0][0],type(X_train[0][0]))28 29     #binary classifier30     for i in range(len(Y_train_)):31 32         if Y_train_[i][0]=="draw":33             Y_train.append(0)34         else:35             Y_train.append(1)36 37 38     # multiple classifer39 40     return cross_validation.train_test_split(X_train,Y_train,test_size=0.25,random_state=0,stratify=Y_train)41 42 43 def GetInputs():44     global flag45     X_train, X_test, Y_train, Y_test = DataPreprocessing()46 47     #print(type(X_train),type(X_train[0][0]))48     #print(X_train)49     # print(len(X_test))50     # print(len(Y_train))51     # print(len(Y_test))52 53 54     #X_train[X_train.isnull().any(axis=1)]55     #X_train.fillna('',inplace=True)56 57     # print(X_train)58     # print(Y_test)59 60     x_train=tf.constant(X_train)61     y_train=tf.constant(Y_train)62     x_test=tf.constant(X_test)63     y_test=tf.constant(Y_test)64     #65     # print(x_train)66     # print(y_train)67     # print(x_test)68     # print(y_test)69 70     if flag==0:71         return x_train,y_train72     else:73         return x_test,y_test74 75 76 def Main():77 78     global flag79 80     feature_columns=[tf.contrib.layers.real_valued_column("",dimension=6)]81 82     clf=tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[20,40,20],n_classes=2,model_dir="/home/jiangjing/TensorflowModel/chess")83 84     clf.fit(input_fn=GetInputs,steps=2000)85 86     flag=187     accuracy_score=clf.evaluate(input_fn=GetInputs,steps=1)["accuracy"]88 89     print("nTest Accuracy:{0:f}".format(accuracy_score))90 91 if __name__ =="__main__":92     #DataPreprocessing()93 94     Main()95 96 exit(0)
View Code

 

转载于:https://www.cnblogs.com/acm-jing/p/9098996.html

你可能感兴趣的文章
Silverlight实用窍门系列:67.Silverlight下的Socket通讯
查看>>
由于这台计算机没有终端服务器客户端访问许可证无法远程
查看>>
WinForm应用程序实现虚拟键盘
查看>>
最新的Javascript和CSS应用技巧荟萃[简介]
查看>>
linux 访问tomcat 管理页面时 You are not authorized to view this page 403(真实可用)
查看>>
在shell中使用sendmail发送邮件
查看>>
SQL Server :理解DCM页
查看>>
QQ邮箱开启SMTP服务的步骤
查看>>
MIDI Test Procedure
查看>>
Audio Latency
查看>>
想从事分布式系统,计算,hadoop等方面,需要哪些基础,推荐哪些书籍?--转自知乎...
查看>>
线程的生命周期
查看>>
设计模式(七):命令模式
查看>>
sqlserver中sp_executesql使用实例(获取动态sql输出结果)
查看>>
C/C++ 中头文件相互包含引发的问题
查看>>
Hive:动静态分区
查看>>
linux设置系统时间
查看>>
把视图转换为字符串
查看>>
Tomcat项目部署方式【转】
查看>>
Java移动文件到另外一个目录
查看>>