Sentence pair classification
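Downloading the pre-trained ALBERT model
Download the pre-trained parameters and the ALBERT source code from GitHub.
Reading the data
In run_classifier.py from the source code, replace the DataProcessor with a data-processing class for your own task.
Data format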
| Label | Sentence 1 | Sentence 2 |
| --- | --- | --- |
| 0 or 1 | first sentence | second sentence |
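A task-specific processor for the three-column format above could look roughly like the sketch below. It is not the code from the ALBERT repository: `InputExample` is a simplified stand-in for the class of the same name in run_classifier.py, `SentencePairProcessor` and `read_tsv` are hypothetical names, and the file is assumed to be tab-separated with no header row.

```python
import csv
import io
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class InputExample:
    """Simplified stand-in for InputExample in run_classifier.py (assumed fields)."""
    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None


def read_tsv(text: str) -> List[List[str]]:
    """Split tab-separated text into rows; quoting is disabled (assumption)."""
    return list(csv.reader(io.StringIO(text), delimiter="\t", quoting=csv.QUOTE_NONE))


class SentencePairProcessor:
    """Hypothetical processor for rows of: label, sentence 1, sentence 2."""

    def get_labels(self) -> List[str]:
        # Binary task: the label column holds "0" or "1".
        return ["0", "1"]

    def create_examples(self, rows: List[List[str]], set_type: str) -> List[InputExample]:
        examples = []
        for i, row in enumerate(rows):
            if len(row) < 3:
                continue  # skip blank or malformed rows
            label, sentence_1, sentence_2 = row[0], row[1], row[2]
            examples.append(
                InputExample(
                    guid=f"{set_type}-{i}",
                    text_a=sentence_1,
                    text_b=sentence_2,
                    label=label,
                )
            )
        return examples
```

In the real script the processor would subclass DataProcessor and be registered under a task name, so the details there should be checked against the downloaded source.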
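Changing the output
In run_classifier.py, in the do_predict part of the main function: the original output is the class probabilities; for binary classification, change it to 0/1.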
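One way to make that change is to write the index of the largest probability instead of the probabilities themselves. The helper names below are hypothetical and the snippet is only a sketch of the idea, not the actual code in run_classifier.py.

```python
from typing import Sequence


def probabilities_to_label(probabilities: Sequence[float]) -> int:
    """Return the index of the largest class probability (0 or 1 for binary tasks)."""
    # On a tie, max() keeps the first index, so the result is 0.
    return max(range(len(probabilities)), key=lambda i: probabilities[i])


def format_output_line(probabilities: Sequence[float]) -> str:
    """Build one output line holding only the predicted label."""
    return f"{probabilities_to_label(probabilities)}\n"
```

Inside the prediction loop, the line that joins the probabilities with tabs would then be replaced by a call such as `format_output_line(probabilities)`.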
Sentence (and sentence-pair) classification tasks
Before running this example you must download the GLUE data by running this script and unpack it to some directory $GLUE_DIR. Next, download the BERT-Base checkpoint and unzip it to some directory $BERT_BASE_DIR.
This example code fine-tunes BERT-Base on the Microsoft Research Paraphrase Corpus (MRPC), which contains only 3,600 examples and can be fine-tuned in a few minutes on most GPUs.
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name=MRPC \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/

You should see output like this:

***** Eval results *****
  eval_accuracy = 0.845588
  eval_loss = 0.505248
  global_step = 343
  loss = 0.505248

This means that the Dev set accuracy was 84.56%. Small sets like MRPC have a high variance in the Dev set accuracy, even when starting from the same pre-training checkpoint. If you re-run multiple times (making sure to point to a different output_dir each time), you should see results between 84% and 88%.
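Because the accuracy varies from run to run, it can help to collect the results of several runs and compare them. The sketch below assumes each run's results are available as text in the `key = value` form shown above; `parse_eval_results` and `summarize_accuracy` are hypothetical helper names, not part of run_classifier.py.

```python
import statistics
from typing import Dict, List


def parse_eval_results(text: str) -> Dict[str, float]:
    """Parse lines of the form 'key = value' into a dict of floats."""
    results = {}
    for line in text.splitlines():
        if "=" not in line:
            continue  # skips the '***** Eval results *****' banner and blank lines
        key, _, value = line.partition("=")
        try:
            results[key.strip()] = float(value)
        except ValueError:
            continue  # ignore values that are not numeric
    return results


def summarize_accuracy(runs: List[str]) -> Dict[str, float]:
    """Report mean, min and max of eval_accuracy over several runs."""
    accuracies = [parse_eval_results(text)["eval_accuracy"] for text in runs]
    return {
        "mean": statistics.mean(accuracies),
        "min": min(accuracies),
        "max": max(accuracies),
    }


if __name__ == "__main__":
    # The single run quoted above; add the text from further runs to the list.
    sample = """***** Eval results *****
  eval_accuracy = 0.845588
  eval_loss = 0.505248
  global_step = 343
  loss = 0.505248
"""
    print(summarize_accuracy([sample]))
```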
A few other tasks are implemented off-the-shelf in run_classifier.py, so it should be straightforward to follow those examples to use BERT for any single-sentence or sentence-pair classification task.
Note: You might see the message "Running train on CPU". This really just means that it's running on something other than a Cloud TPU, which includes a GPU.
Prediction from classifier
Once you have trained your classifier, you can use it in inference mode by passing the --do_predict=true flag. You need to have a file named test.tsv in the input folder. The output is written to a file called test_results.tsv in the output folder. Each line contains the output for one sample; the columns are the class probabilities.
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier

python run_classifier.py \
  --task_name=MRPC \
  --do_predict=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=/tmp/mrpc_output/
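If the script is left unmodified, the probabilities in test_results.tsv can also be turned into labels after the fact. The following is a minimal sketch that assumes one tab-separated row of probabilities per sample, as described above; `read_predictions` is a hypothetical helper, not part of the repository.

```python
import csv
from typing import List, Tuple


def read_predictions(path: str) -> List[Tuple[int, float]]:
    """Return (predicted class index, its probability) for each row of the file."""
    predictions = []
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            if not row:
                continue  # skip empty lines
            probabilities = [float(value) for value in row]
            # Pick the column with the highest probability as the predicted class.
            label = max(range(len(probabilities)), key=probabilities.__getitem__)
            predictions.append((label, probabilities[label]))
    return predictions
```

With the command above, the file would be read with `read_predictions("/tmp/mrpc_output/test_results.tsv")`.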