TensorFlow NMT数据处理实战解析

在tensorflow/nmt项目中,训练数据和推断数据的输入使用了新的Dataset API,应该是tensorflow 1.2之后引入的API,方便数据的操作。如果你还在使用老的Queue和Coordinator的方式,建议升级高版本的tensorflow并且使用Dataset API。

本教程将从训练数据和推断数据两个方面,详解解析数据的具体处理过程,你将看到文本数据如何转化为模型所需要的实数,以及中间的张量的维度是怎么样的,batch_size和其他超参数又是如何作用的。

训练数据的处理

先来看看训练数据的处理。训练数据的处理比推断数据的处理稍微复杂一些,弄懂了训练数据的处理过程,就可以很轻松地理解推断数据的处理。

训练数据的处理代码位于nmt/utils/iterator_utils.py文件内的​​get_iterator​​函数。

函数的参数

我们先来看看这个函数所需要的参数是什么意思:


参数解释
​​src_dataset​​源数据集
​​tgt_dataset​​目标数据集
​​src_vocab_table​​源数据单词查找表,就是个单词和int类型数据的对应表
​​tgt_vocab_table​​目标数据单词查找表,就是个单词和int类型数据的对应表
​​batch_size​​批大小
​​sos​​句子开始标记
​​eos​​句子结尾标记
​​random_seed​​随机种子,用来打乱数据集的
​​num_buckets​​桶数量
​​src_max_len​​源数据最大长度
​​tgt_max_len​​目标数据最大长度
​​num_parallel_calls​​并发处理数据的并发数
​​output_buffer_size​​输出缓冲区大小
​​skip_count​​跳过数据行数
​​num_shards​​将数据集分片的数量,分布式训练中有用
​​shard_index​​数据集分片后的id
​​reshuffle_each_iteration​​是否每次迭代都重新打乱顺序

上面的解释,如果有不清楚的,可以查看我之前一片介绍超参数的文章:

​ ​tensorflow_nmt的超参数详解​​



我们首先搞清楚几个重要的参数是怎么来的。

​src_dataset​​和​​tgt_dataset​​是我们的训练数据集,他们是逐行一一对应的。比如我们有两个文件​​src_data.txt​​和​​tgt_data.txt​​分别对应训练数据的源数据和目标数据,那么它们的Dataset如何创建的呢?其实利用Dataset API很简单:



src_dataset=tf.data.TextLineDataset('src_data.txt')  
tgt_dataset=tf.data.TextLineDataset('tgt_data.txt')  1.2.


这就是上述函数中的两个参数​​src_dataset​​和​​tgt_dataset​​的由来。

​src_vocab_table​​和​​tgt_vocab_table​​是什么呢?同样顾名思义,就是这两个分别代表源数据词典的查找表和目标数据词典的查找表,实际上查找表就是一个字符串到数字的映射关系。当然,如果我们的源数据和目标数据使用的是同一个词典,那么这两个查找表的内容是一模一样的。很容易想到,肯定也有一种数字到字符串的映射表,这是肯定的,因为神经网络的数据是数字,而我们需要的目标数据是字符串,因此它们之间肯定有一个转换的过程,这个时候,就需要我们的reverse_vocab_table来作用了。

我们看看这两个表是怎么构建出来的呢?代码很简单,利用tensorflow库中定义的lookup_ops即可:



def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab):  """
Creates vocab tables for src_vocab_file and tgt_vocab_file."""  
src_vocab_table = lookup_ops.index_table_from_file(      
src_vocab_file, default_value=UNK_ID)  if share_vocab:    
tgt_vocab_table = src_vocab_table  else:    
tgt_vocab_table = lookup_ops.index_table_from_file(        
tgt_vocab_file, default_value=UNK_ID)  return src_vocab_table, tgt_vocab_table1.2.3.4.5.6.7.8.9.10.


我们可以发现,创建这两个表的过程,就是将词典中的每一个词,对应一个数字,然后返回这些数字的集合,这就是所谓的词典查找表。效果上来说,就是对词典中的每一个词,从0开始递增的分配一个数字给这个词。

那么到这里你有可能会有疑问,我们词典中的词和我们自定义的标记​​sos​​等是不是有可能被映射为同一个整数而造成冲突?这个问题该如何解决?聪明如你,这个问题是存在的。那么我们的项目是如何解决的呢?很简单,那就是将我们自定义的标记当成词典的单词,然后加入到词典文件中,这样一来,​​lookup_ops​​操作就把标记当成单词处理了,也就就解决了冲突!

具体的过程,本文后面会有一个例子,可以为您呈现具体过程。

如果我们指定了​​share_vocab​​参数,那么返回的源单词查找表和目标单词查找表是一样的。我们还可以指定一个default_value,在这里是​​UNK_ID​​,实际上就是​​0​​。如果不指定,那么默认值为​​-1​​。这就是查找表的创建过程。如果你想具体的知道其代码实现,可以跳转到tensorflow的C++核心部分查看代码(使用PyCharm或者类似的IDE)。

数据集的处理过程

该函数处理训练数据的主要代码如下:


if not output_buffer_size:    output_buffer_size = batch_size * 1000  src_eos_id = tf
.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)  tgt_sos_id = tf.cast(tgt_vocab_table
.lookup(tf.constant(sos)), tf.int32)  tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf
.int32)  src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))  src_tgt_dataset = src_tgt_dataset
.shard(num_shards, shard_index)  if skip_count is not None:    src_tgt_dataset = src_tgt_dataset
.skip(skip_count)  src_tgt_dataset = src_tgt_dataset.shuffle(      
output_buffer_size, random_seed, reshuffle_each_iteration)  src_tgt_dataset = src_tgt_dataset.map(      
lambda src, tgt: (          tf.string_split([src]).values, tf.string_split([tgt]).values),      
num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)  # Filter zero length input sequences.  
src_tgt_dataset = src_tgt_dataset.filter(      
lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))  
if src_max_len:    src_tgt_dataset = src_tgt_dataset.map(        lambda src, 
tgt: (src[:src_max_len], tgt),        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)  
if tgt_max_len:    src_tgt_dataset = src_tgt_dataset.map(        
lambda src, tgt: (src, tgt[:tgt_max_len1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.
27.28.29.30.31.



免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

相关推荐
技术文档
软件下载
QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空