Training a Weibo.cn Captcha Recognizer with TensorFlow
Published: 2019-06-28


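Lately I have been making time to learn TensorFlow, a deep learning library, though only on and off. The first example on the official site is MNIST handwritten-digit recognition. When I worked on recognizing Weibo.cn captchas earlier, I built my own dataset and trained on it with the C++ library tiny-dnn (see: ). Now I want to try porting that to TensorFlow.

The complete code is at:

Libraries used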

  • TensorFlow-1.0

  • scikit-learn-0.18

  • pillow

Loading the dataset

Dataset download link:

After extracting it, it looks like this:

dataset

I put images of the same class into one folder, and the folder name is the label of those images. Opening a folder shows the images of that character.

dataset_detail

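Next, we load the data into a pickle file. It needs four variables, train_dataset, train_labels, test_dataset, and test_labels, holding the data and labels of the training and test sets.

It also needs a label_map that maps the labels used in training to the actual labels, for example 3 to the letter M and 4 to the letter N.

The code for this part is at: . Note: much of the code is adapted from Udacity's deep learning course.

First, load all the data folder by folder. index is the label used in training and label is the actual label. Each image is read with PIL and converted to a numpy array.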

    import os

    import numpy as np
    from PIL import Image


    def load_dataset():
        dataset = []
        labelset = []
        label_map = {}

        base_dir = "../trainer/training_set/"  # location of the dataset
        # Drop non-class entries before enumerating so the indices stay
        # contiguous (0 .. len(label_map) - 1), as the one-hot encoding expects.
        labels = [label for label in os.listdir(base_dir)
                  if label not in ("ERROR", ".DS_Store")]
        for index, label in enumerate(labels):
            print("loading: %s index: %d" % (label, index))
            try:
                image_files = os.listdir(base_dir + label)
                for image_file in image_files:
                    image_path = base_dir + label + "/" + image_file
                    # Read as 8-bit grayscale and convert to a float32 array.
                    im = Image.open(image_path).convert('L')
                    dataset.append(np.asarray(im, dtype=np.float32))
                    labelset.append(index)
                label_map[index] = label
            except (IOError, OSError):
                # Skip folders or files that cannot be read.
                pass

        return np.array(dataset), np.array(labelset), label_map


    dataset, labelset, label_map = load_dataset()
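
The index-to-label mapping can also be looked at in isolation. The following is only a sketch under a few assumptions: build_label_map and the temporary folder layout are made up for illustration, and the folder names are sorted so the result is repeatable (the loader above uses whatever order os.listdir returns).

```python
import os
import tempfile

SKIP = {"ERROR", ".DS_Store"}  # entries the loader ignores


def build_label_map(base_dir):
    """Map contiguous integer indices to class-folder names."""
    # Filter first, then enumerate, so indices run 0..n-1 with no gaps.
    names = sorted(
        name for name in os.listdir(base_dir)
        if name not in SKIP and os.path.isdir(os.path.join(base_dir, name))
    )
    return {index: name for index, name in enumerate(names)}


# Hypothetical layout: one folder per character, plus an ERROR folder.
demo_dir = tempfile.mkdtemp()
for name in ["M", "N", "3", "ERROR"]:
    os.mkdir(os.path.join(demo_dir, name))

label_map = build_label_map(demo_dir)
print(label_map)  # {0: '3', 1: 'M', 2: 'N'}
```

Keeping the indices gap-free matters later, because the number of classes is taken from len(label_map).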

Next, shuffle the data.

    def randomize(dataset, labels):
        permutation = np.random.permutation(labels.shape[0])
        shuffled_dataset = dataset[permutation, :, :]
        shuffled_labels = labels[permutation]
        return shuffled_dataset, shuffled_labels


    dataset, labelset = randomize(dataset, labelset)
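
The point of using one shared permutation is that every image keeps its own label. A small check on toy data makes this visible; the arrays and the fixed seed below are illustrative assumptions, not part of the original code.

```python
import numpy as np

# Toy data: image i is filled with the value i, and its label is also i,
# so a mismatch after shuffling would be easy to spot.
images = np.stack([np.full((2, 2), i, dtype=np.float32) for i in range(5)])
labels = np.arange(5)

rng = np.random.RandomState(0)  # fixed seed, only for repeatability
permutation = rng.permutation(labels.shape[0])

# Index both arrays with the same permutation.
shuffled_images = images[permutation, :, :]
shuffled_labels = labels[permutation]

# Every shuffled image still carries its own label.
aligned = all(img[0, 0] == lab for img, lab in zip(shuffled_images, shuffled_labels))
print(aligned)  # True
```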

Then use a scikit-learn function to split the data into a training set and a test set.

    from sklearn.model_selection import train_test_split

    train_dataset, test_dataset, train_labels, test_labels = train_test_split(dataset, labelset)
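
Called without any size arguments, train_test_split holds out 25% of the samples as the test set. To see what that amounts to, here is a rough hand-written stand-in; simple_split is a hypothetical helper for illustration, not scikit-learn's implementation, and the toy arrays only mimic the shapes used in this post.

```python
import numpy as np


def simple_split(dataset, labels, test_fraction=0.25, seed=0):
    """Shuffle once, then cut off the first `test_fraction` as the test set."""
    n = labels.shape[0]
    n_test = int(np.ceil(n * test_fraction))
    order = np.random.RandomState(seed).permutation(n)
    test_idx, train_idx = order[:n_test], order[n_test:]
    # Same return order as train_test_split: data first, then labels.
    return dataset[train_idx], dataset[test_idx], labels[train_idx], labels[test_idx]


# Toy stand-in for the real data: 8 "images" of 32x32 and 8 integer labels.
dataset = np.zeros((8, 32, 32), dtype=np.float32)
labelset = np.arange(8)

train_dataset, test_dataset, train_labels, test_labels = simple_split(dataset, labelset)
print(train_dataset.shape, test_dataset.shape)  # (6, 32, 32) (2, 32, 32)
```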

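In the example on the TensorFlow website, the labels are one-hot encoded and each 28*28 image is flattened into a one-dimensional vector (784). The figure below shows the model from the official example.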

minist_data

I converted my data the same way: each 32*32 image becomes a one-dimensional vector (1024), and the labels are one-hot encoded.

    def reformat(dataset, labels, image_size, num_labels):
        # Flatten each image_size x image_size image into one row.
        dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
        # Map 1 to [0.0, 1.0, 0.0 ...], 2 to [0.0, 0.0, 1.0 ...]
        labels = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
        return dataset, labels


    train_dataset, train_labels = reformat(train_dataset, train_labels, 32, len(label_map))
    test_dataset, test_labels = reformat(test_dataset, test_labels, 32, len(label_map))
    print("train_dataset: %s" % (train_dataset.shape,))
    print("train_labels: %s" % (train_labels.shape,))
    print("test_dataset: %s" % (test_dataset.shape,))
    print("test_labels: %s" % (test_labels.shape,))
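
The one-hot line relies on NumPy broadcasting, which is easier to follow on a tiny input. The values below are made up for illustration; only the expression itself comes from the code above.

```python
import numpy as np

num_labels = 4
labels = np.array([1, 3, 0])

# labels[:, None] has shape (3, 1); np.arange(num_labels) has shape (4,).
# Broadcasting compares every label with every class index -> shape (3, 4).
one_hot = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
print(one_hot)

# argmax along axis 1 recovers the original integer labels.
recovered = np.argmax(one_hot, axis=1)

# Flattening: a batch of two 32x32 images becomes two rows of 1024 values.
batch = np.zeros((2, 32, 32), dtype=np.float32)
flat = batch.reshape((-1, 32 * 32))
print(flat.shape)  # (2, 1024)
```

Note that a label equal to or larger than num_labels would produce an all-zero row, so the label indices need to stay within 0 .. num_labels - 1.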

After the conversion, the format matches MNIST.

reformat

Finally, save the data to save.pickle.

    import pickle

    save = {
        'train_dataset': train_dataset,
        'train_labels': train_labels,
        'test_dataset': test_dataset,
        'test_labels': test_labels,
        'label_map': label_map
    }
    with open("save.pickle", 'wb') as f:
        pickle.dump(save, f)

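Verifying that the dataset loaded correctly

After loading the data, it is worth verifying that it is correct. My method is simple: open the first image in the training set (or the second, or the n-th) and check whether its label matches what the image shows.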

    import pickle

    import numpy as np
    from PIL import Image


    def check_dataset(dataset, labels, label_map, index):
        # Restore the flat 1024-vector to a 32x32 8-bit image.
        data = np.uint8(dataset[index]).reshape((32, 32))
        # The position of the 1 in the one-hot row is the class index.
        i = np.argwhere(labels[index] == 1)[0][0]
        im = Image.fromarray(data)
        im.show()
        print("label: %s" % label_map[i])


    if __name__ == '__main__':
        with open("save.pickle", 'rb') as f:
            save = pickle.load(f)
            train_dataset = save['train_dataset']
            train_labels = save['train_labels']
            test_dataset = save['test_dataset']
            test_labels = save['test_labels']
            label_map = save['label_map']

        # check that the image corresponds to its label
        check_dataset(train_dataset, train_labels, label_map, 0)

Running it shows that the first image is a Y, and the label is correct too.

check_dataset

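Training

With the data loaded, training can begin. For the network, I simply use the one provided on the TensorFlow website.

The code for this part is at: .

First, load the data: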

    import pickle

    import numpy as np
    import tensorflow as tf

    with open("save.pickle", 'rb') as f:
        save = pickle.load(f)
        train_dataset = save['train_dataset']
        train_labels = save['train_labels']
        test_dataset = save['test_dataset']
        test_labels = save['test_labels']
        label_map = save['label_map']

    image_size = 32
    num_labels = len(label_map)

    print("train_dataset: %s" % (train_dataset.shape,))
    print("train_labels: %s" % (train_labels.shape,))
    print("test_dataset: %s" % (test_dataset.shape,))
    print("test_labels: %s" % (test_labels.shape,))
    print("num_labels: %d" % num_labels)

MNIST images are all 28*28. After adapting the network from that example, it looks like this:

    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)


    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)


    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1], padding='SAME')


    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, image_size * image_size])
        y_ = tf.placeholder(tf.float32, shape=[None, num_labels])

        x_image = tf.reshape(x, [-1, 32, 32, 1])

        # First Convolutional Layer
        W_conv1 = weight_variable([5, 5, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
        h_pool1 = max_pool_2x2(h_conv1)

        # Second Convolutional Layer
        W_conv2 = weight_variable([5, 5, 32, 64])
        b_conv2 = bias_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
        h_pool2 = max_pool_2x2(h_conv2)

        # Densely Connected Layer
        # Integer division keeps the shape an int; two 2x2 pools shrink each side by 4.
        W_fc1 = weight_variable([(image_size // 4) * (image_size // 4) * 64, 1024])
        b_fc1 = bias_variable([1024])
        h_pool2_flat = tf.reshape(h_pool2, [-1, (image_size // 4) * (image_size // 4) * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

        # Dropout
        keep_prob = tf.placeholder(tf.float32)
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

        # Readout Layer
        W_fc2 = weight_variable([1024, num_labels])
        b_fc2 = bias_variable([num_labels])
        y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

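The main changes: the input layer goes from 28*28 to image_size*image_size (32*32); the fully connected third layer goes from 7*7 to image_size/4*image_size/4 (8*8); and 10 (the ten handwritten digit classes) becomes num_labels.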
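
The 7*7 and 8*8 figures follow from the two 2x2 max-pooling layers: with 'SAME' padding and stride 2, each pool halves the side length (rounding up), while the stride-1 convolutions leave it unchanged. The helper names below are made up for this sketch; the arithmetic is the only thing being shown.

```python
import math


def same_pool_output(size, stride=2):
    """Spatial size after a 'SAME'-padded op with the given stride."""
    return int(math.ceil(size / float(stride)))


def flattened_features(image_size, num_pools=2, channels=64):
    """Side length and flattened width that feed the fully connected layer."""
    side = image_size
    for _ in range(num_pools):
        side = same_pool_output(side)  # convs use stride 1, so only pools shrink
    return side, side * side * channels


mnist_side, mnist_flat = flattened_features(28)
weibo_side, weibo_flat = flattened_features(32)
print(mnist_side, mnist_flat)  # 7 3136
print(weibo_side, weibo_flat)  # 8 4096
```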

Then train. Here I changed batch_size to 128 and reduced the number of training steps.

    batch_size = 128

    with tf.Session(graph=graph) as session:
        tf.global_variables_initializer().run()
        print("Initialized")
        for step in range(2001):
            offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
            # Generate a minibatch.
            batch_data = train_dataset[offset:(offset + batch_size), :]
            batch_labels = train_labels[offset:(offset + batch_size), :]
            if step % 50 == 0:
                train_accuracy = accuracy.eval(feed_dict={
                    x: batch_data, y_: batch_labels, keep_prob: 1.0})
                test_accuracy = accuracy.eval(feed_dict={
                    x: test_dataset, y_: test_labels, keep_prob: 1.0})
                print("Step %d, Training accuracy: %g, Test accuracy: %g" % (step, train_accuracy, test_accuracy))
            train_step.run(feed_dict={x: batch_data, y_: batch_labels, keep_prob: 0.5})

        print("Test accuracy: %g" % accuracy.eval(feed_dict={
            x: test_dataset, y_: test_labels, keep_prob: 1.0}))
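
The offset formula in the loop walks through the training set in batch-sized steps and wraps around before a slice could run past the end. A standalone look at the first few offsets, assuming a hypothetical training set of 1000 samples (the real size depends on the split):

```python
def batch_offsets(num_samples, batch_size, num_steps):
    """Start index of each minibatch, using the same formula as the loop."""
    return [(step * batch_size) % (num_samples - batch_size)
            for step in range(num_steps)]


num_samples = 1000  # hypothetical size, for illustration only
batch_size = 128

offsets = batch_offsets(num_samples, batch_size, 10)
print(offsets)  # [0, 128, 256, 384, 512, 640, 768, 24, 152, 280]

# Every slice [offset, offset + batch_size) stays inside the data.
in_bounds = all(o + batch_size <= num_samples for o in offsets)
print(in_bounds)  # True
```

Because the offset is taken modulo num_samples - batch_size, every batch has exactly batch_size rows, at the cost of the batches after the wrap-around overlapping earlier ones.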

Run it, and the recognition rate climbs steadily.

train

In the end the recognition rate was close to 98%. With only 4000 training samples, that feels pretty good.

train_last

Reposted from: http://xigxl.baihongyu.com/
