Open In Colab

English to Hindi Transliteration Tutorial

Setup Environment

  • On Google Colab make sure you select Python 3/GPU runtime before running the code

Choose Python 3 + GPU/CPU

</img> </img>

GPU

%env CUDA_VISIBLE_DEVICES=0
env: CUDA_VISIBLE_DEVICES=0

Download Data

# -p makes the cell re-runnable: plain `mkdir` errors out when data/ already exists
!mkdir -p data
!wget -N "https://raw.githubusercontent.com/bsantraigi/tensorflow-seq2seq-hindi/master/data/Hindi%20-%20Word%20Transliteration%20Pairs%201.txt" -P data/
mkdir: cannot create directory ‘data’: File exists
--2019-08-10 20:10:10--  https://raw.githubusercontent.com/bsantraigi/tensorflow-seq2seq-hindi/master/data/Hindi%20-%20Word%20Transliteration%20Pairs%201.txt
Connecting to 172.16.2.30:8080... connected.
Proxy request sent, awaiting response... 200 OK
Length: 773211 (755K) [text/plain]
Last-modified header missing -- time-stamps turned off.
--2019-08-10 20:10:11--  https://raw.githubusercontent.com/bsantraigi/tensorflow-seq2seq-hindi/master/data/Hindi%20-%20Word%20Transliteration%20Pairs%201.txt
Connecting to 172.16.2.30:8080... connected.
Proxy request sent, awaiting response... 200 OK
Length: 773211 (755K) [text/plain]
Saving to: ‘data/Hindi - Word Transliteration Pairs 1.txt’

100%[======================================>] 7,73,211     714KB/s   in 1.1s   

2019-08-10 20:10:13 (714 KB/s) - ‘data/Hindi - Word Transliteration Pairs 1.txt’ saved [773211/773211]

Import Stuff

import nltk
from collections import Counter
from tqdm import tqdm_notebook
import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.contrib.rnn import DropoutWrapper
import random
nltk.download('punkt')
[nltk_data] Downloading package punkt to /home/bishal/nltk_data...
[nltk_data]   Package punkt is already up-to-date!





True

Global Parameters

# Fixed width every encoded character sequence is padded/truncated to;
# also the second dimension of the id placeholders.
MAX_SEQ_LEN = 20
# Number of word pairs per minibatch (used for training, validation and inference batches)
BATCH_SIZE = 64

Language Vocabulary

  • (Vocab of characters, i.e. an Alphabet)
class Lang:
    """Character-level vocabulary mapping symbols to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; the remaining ids are assigned to the
    ``vocab_size`` most common symbols of ``counter``, most frequent first.
    """

    def __init__(self, counter, vocab_size):
        """Build the vocabulary.

        Args:
            counter: ``collections.Counter`` of symbol frequencies.
            vocab_size: Maximum number of (non-special) symbols to keep.
        """
        self.word2id = {}
        self.id2word = {}
        self.pad = "<PAD>"
        self.sos = "<SOS>"
        self.eos = "<EOS>"
        self.unk = "<UNK>"

        self.ipad = 0
        self.isos = 1
        self.ieos = 2
        self.iunk = 3

        specials = (
            (self.pad, self.ipad),
            (self.sos, self.isos),
            (self.eos, self.ieos),
            (self.unk, self.iunk),
        )
        for token, idx in specials:
            self.word2id[token] = idx
            self.id2word[idx] = token

        # Regular symbols start right after the four special ids.
        for curr_id, (w, _) in enumerate(counter.most_common(vocab_size), start=4):
            self.word2id[w] = curr_id
            self.id2word[curr_id] = w

    def _encode(self, s):
        """Return the ids of the lower-cased, stripped characters of ``s``.

        Characters missing from the vocabulary map to ``<UNK>``.
        """
        return [self.word2id.get(w, self.iunk) for w in s.lower().strip()]

    def encodeSentence(self, s, max_len=-1):
        """Encode ``s`` as a list of character ids.

        Args:
            s: Input string.
            max_len: ``-1`` (default) returns the raw ids with no EOS and no
                padding. Otherwise EOS is appended and the result is padded
                with PAD / truncated to exactly ``max_len`` ids (EOS is lost
                when the input alone fills ``max_len``).

        Returns:
            List of integer ids.
        """
        ids = self._encode(s)
        if max_len == -1:
            return ids
        return (ids + [self.ieos] + [self.ipad] * max_len)[:max_len]

    def encodeSentence2(self, s, max_len=-1):
        """Encode ``s`` with a trailing EOS and also return its length.

        Args:
            s: Input string.
            max_len: ``-1`` (default) returns the unpadded, untruncated
                sequence including EOS. Otherwise the sequence is padded
                with PAD / truncated to exactly ``max_len`` ids.

        Returns:
            ``(length, ids)`` where ``length`` counts the characters plus
            EOS, capped at ``max_len`` when a limit is given.
        """
        ids = self._encode(s) + [self.ieos]
        if max_len == -1:
            # Previously this case produced length -1 and sliced off the EOS.
            return len(ids), ids
        return min(max_len, len(ids)), (ids + [self.ipad] * max_len)[:max_len]

    def decodeSentence(self, id_seq):
        """Turn a sequence of ids back into a string.

        Decoding stops at the first EOS (or at the end of the sequence when
        there is none). ``<UNK>`` is rendered as ``UNK``.

        Args:
            id_seq: Iterable of integer ids (list or 1-D array).

        Returns:
            The decoded string.
        """
        symbols = []
        for idx in id_seq:
            if idx == self.ieos:
                break
            # int() lets NumPy integer scalars index the plain-int dict keys.
            symbols.append(self.id2word[int(idx)])
        return ''.join(symbols).replace(self.unk, "UNK")
# Expected number of word pairs (lines) in the data file; used as the
# progress-bar total while reading and to size the train/validation split.
N = 30823

Reading the data files

  • Each line contains a Hindi word in both English (Latin) and Devanagari script
# Parallel lists of Latin-script / Devanagari spellings plus per-script
# character frequency counters (used later to build the vocabularies).
hi_counter = Counter()
hi_sentences = []
en_counter = Counter()
en_sentences = []

# Explicit UTF-8: the file holds Devanagari text and the platform default
# encoding is not guaranteed to be UTF-8.
with open("data/Hindi - Word Transliteration Pairs 1.txt", encoding="utf-8") as f:
    for line in tqdm_notebook(f, total=N, desc="Reading file:"):
        # Each line is "<english>\t<hindi>"
        en, hi = line.strip().split("\t")
        hi_sentences.append(hi)
        en_sentences.append(en)

# Counting only needs the in-memory lists, so it runs after the file is closed.
for line in tqdm_notebook(hi_sentences, desc="Processing inputs:"):
    hi_counter.update(line.strip())
for line in tqdm_notebook(en_sentences, desc="Processing inputs:"):
    en_counter.update(line.strip())
HBox(children=(IntProgress(value=0, description='Reading file:', max=30823, style=ProgressStyle(description_wi…






HBox(children=(IntProgress(value=0, description='Processing inputs:', max=30823, style=ProgressStyle(descripti…






HBox(children=(IntProgress(value=0, description='Processing inputs:', max=30823, style=ProgressStyle(descripti…
# A few sample hindi characters
# Summarise both alphabets: top-5 characters and number of distinct characters
hi_top5 = hi_counter.most_common(5)
hi_total = len(hi_counter)
en_top5 = en_counter.most_common(5)
en_total = len(en_counter)

# Hindi (Devanagari) characters
print("Most common hi characters in dataset:\n", hi_top5)

print("\nTotal (hi)characters gathered from dataset:", hi_total)

# English (Latin) characters
print("\nMost common en characters in dataset:\n", en_top5)

print("\nTotal (en)characters gathered from dataset:", en_total)
Most common hi characters in dataset:
 [('ा', 21123), ('र', 9205), ('े', 8100), ('न', 7225), ('ी', 6546)]

Total (hi)characters gathered from dataset: 66

Most common en characters in dataset:
 [('a', 57220), ('n', 15015), ('i', 14015), ('h', 13805), ('e', 12264)]

Total (en)characters gathered from dataset: 27
# Build one vocabulary per script, keeping every character seen in the data
en_lang = Lang(en_counter, len(en_counter))
hi_lang = Lang(hi_counter, len(hi_counter))

# Round-trip smoke tests: encode a word, then decode the padded encoding back
en_raw_ids = en_lang.encodeSentence("Shukriya")
print("Test en encoding:", en_raw_ids)

en_padded_ids = en_lang.encodeSentence("Shukriya", 10)
print("Test en decoding:", en_lang.decodeSentence(en_padded_ids))

hi_padded_ids = hi_lang.encodeSentence("शुक्रिया", 10)
print("Test hindi encoding:", hi_padded_ids)

hi_roundtrip_ids = hi_lang.encodeSentence("शुक्रिया", 10)
print("Test hindi decoding:", hi_lang.decodeSentence(hi_roundtrip_ids))
Test en encoding: [15, 7, 10, 13, 9, 6, 20, 4]
Test en decoding: shukriya
Test hindi encoding: [35, 19, 15, 22, 5, 12, 21, 4, 2, 0]
Test hindi decoding: शुक्रिया
# Vocabulary sizes (special tokens included) of the English and Hindi alphabets;
# they set the number of rows of the two embedding matrices.
VE = len(en_lang.word2id)
VH = len(hi_lang.word2id)

The Seq2Seq architecture

  • We will implement a seq2seq architecture for transliteration in Tensorflow r1.13.1 / r1.14
  • Debugging Tip: Always keep track of tensor dimensions!
  • Tensorflow Computation Graph - We will build a tf computation graph first. This is the representation used by tf for any neural network architecture. Once the computation graph is built, you can feed data to it for training or inference

Character Embedding Matrix

# Embedding tables: one 300-dimensional float32 row per vocabulary id,
# for the English (encoder) and Hindi (decoder) alphabets respectively.
en_word_emb_matrix = tf.get_variable("en_word_emb_matrix", (VE, 300), dtype=tf.float32)
hi_word_emb_matrix = tf.get_variable("hi_word_emb_matrix", (VH, 300), dtype=tf.float32)
WARNING:tensorflow:From /home/bishal/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.

Placeholders

  • Input to a TensorFlow graph is fed through placeholders
# Dropout keep probability, fed on every run (0.8 while training, 1.0 for evaluation)
keep_prob = tf.placeholder(tf.float32)

# Encoded English character ids (one padded row per word) and their true lengths
input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LEN))
input_lens = tf.placeholder(tf.int32, (None, ))

# Encoded Hindi character ids (one padded row per word) and their true lengths
ph_target_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LEN))
target_lens = tf.placeholder(tf.int32, (None, ))
# Add SOS or GO symbol.
# The number of SOS rows is taken from the fed batch instead of BATCH_SIZE so
# it always matches the (dynamic) batch dimension of the placeholder.
target_ids = tf.concat([tf.fill([tf.shape(ph_target_ids)[0], 1], hi_lang.isos), ph_target_ids], -1)

Building the computation graph

# Look up an embedding vector for every input character id
input_emb = tf.nn.embedding_lookup(en_word_emb_matrix, input_ids)
# Decoder inputs: the SOS-prefixed targets with the last position dropped, so the
# input at step t is the target character of step t-1 (SOS at step 0).
target_emb = tf.nn.embedding_lookup(hi_word_emb_matrix, target_ids[:, :-1])
# Notebook display: check the embedded input shape
input_emb.shape
TensorShape([Dimension(None), Dimension(20), Dimension(300)])

Encoder - RNN based sequence encoder

# Recurrent cell used by the encoder; dropout strength is controlled at run time via `keep_prob`
encoder_cell = tf.nn.rnn_cell.GRUCell(128) # 128 is the dimension of hidden state
encoder_cell = DropoutWrapper(encoder_cell, output_keep_prob=keep_prob) # Adding Dropout for regularization
WARNING:tensorflow:From <ipython-input-18-73573fbda69f>:1: GRUCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is equivalent as tf.keras.layers.GRUCell, and will be replaced by that in Tensorflow 2.0.
# Run the encoder over the embedded input. No explicit initial state is given:
# with `dtype` set, dynamic_rnn starts from a zero state sized to the incoming
# batch, which avoids hard-coding BATCH_SIZE into the encoder.
enc_outputs, enc_state = tf.nn.dynamic_rnn(
    encoder_cell, # The encoder GRU cell
    input_emb, # Embedded input sequence
    sequence_length=input_lens, # Sequence lengths of individual inputs in a batch
    dtype=tf.float32 # dtype of the implicit all-zero initial state
)
WARNING:tensorflow:From <ipython-input-19-1e698e82f454>:5: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
WARNING:tensorflow:From /home/bishal/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:626: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From /home/bishal/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py:1259: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
# Confirm the shape of the final hidden state
enc_state.shape
TensorShape([Dimension(64), Dimension(128)])

Decoder

# Decoder GRU with the same hidden size (128) as the encoder, so the encoder's
# final state can be used directly as the decoder's initial state.
decoder_cell = tf.nn.rnn_cell.GRUCell(128)
# Dropout on the cell output, controlled at run time through `keep_prob`
decoder_cell = DropoutWrapper(decoder_cell, output_keep_prob=keep_prob)

Decoder to Output Vocab Projection Layer

# Dense layer mapping each decoder output to one score per Hindi vocabulary id
output_projection = tf.layers.Dense(len(hi_lang.word2id))

Decoder Training Helper

# Training-time decoding: the helper feeds the ground-truth target embeddings
# as decoder inputs, limited to each example's `target_lens`.
helper = seq2seq.TrainingHelper(target_emb, target_lens)
# The decoder starts from the encoder's final state and projects every step
# onto the Hindi vocabulary.
decoder = seq2seq.BasicDecoder(decoder_cell, helper, enc_state, output_projection)
# dynamic_decode returns (final outputs, final state, per-example decoded lengths)
outputs, _, outputs_lens = seq2seq.dynamic_decode(decoder, maximum_iterations=MAX_SEQ_LEN, 
                                                  impute_finished=False, swap_memory=True)
# Longest decoded length in the batch; used to size the loss mask and target slice
output_max_len = tf.reduce_max(outputs_lens)

And Decoder Inference Helper

# Using the decoder_cell without dropout here.
infer_helper = seq2seq.GreedyEmbeddingHelper(hi_word_emb_matrix, tf.fill([BATCH_SIZE, ], hi_lang.isos), hi_lang.ieos)
infer_decoder = seq2seq.BasicDecoder(decoder_cell, infer_helper, enc_state, output_projection)
infer_output = seq2seq.dynamic_decode(infer_decoder, maximum_iterations=MAX_SEQ_LEN, swap_memory=True)

Loss and Optimizers

# Sequence mask:
# 1.0 for positions inside each target's true length and 0.0 beyond it, so that
# padded positions do not contribute to the loss (and hence to the gradients).
masks = tf.sequence_mask(target_lens, output_max_len, dtype=tf.float32, name='masks')

# Loss function - weighted softmax cross entropy
# outputs[0] is the first field of the decoder output (presumably the projected
# logits - confirm against BasicDecoderOutput). The targets skip the prepended
# SOS column and are cut to the decoded length so both tensors line up.
cost = seq2seq.sequence_loss(
    outputs[0],
    target_ids[:, 1:(output_max_len + 1)],
    masks)

# Optimizer: Adam with learning rate 1e-4
optimizer = tf.train.AdamOptimizer(0.0001)
train_op = optimizer.minimize(cost)
# Created after the optimizer so that the optimizer's own variables are initialised too
init = tf.global_variables_initializer()

Tensorflow Sessions

# Session setup: let TensorFlow grow its GPU memory allocation as needed
# instead of reserving it up front.
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
# An InteractiveSession installs itself as the default session for the notebook
sess = tf.InteractiveSession(config=sess_config)
# Initialise all graph variables
sess.run(init)

Minibatch Training + Validation

  • Performance Evaluation using BLEU scores
# Fixed seed so the shuffle (and therefore the train/validation split) is reproducible
random.seed(41)
# Pair every English spelling with its Hindi counterpart, then shuffle the pairs
parallel = [(en_word, hi_word) for en_word, hi_word in zip(en_sentences, hi_sentences)]
random.shuffle(parallel)
# Notebook display: peek at one pair
parallel[1000]
('hazaarii', 'हज़ारी')
# 95% / 5% train/validation split. Sizes are derived from the pairs actually
# loaded rather than the hard-coded N, so the counts always match the lists.
train_n = int(0.95 * len(parallel))
valid_n = len(parallel) - train_n
# Slicing already creates new lists, so shuffling train_pairs later does not
# disturb `parallel` or `valid_pairs`.
train_pairs = parallel[:train_n]
valid_pairs = parallel[train_n:]
def small_test():
    """Greedy-decode the validation set and report the mean character-level BLEU.

    Validation pairs are processed in full batches of BATCH_SIZE (the batch
    arrays built here are always BATCH_SIZE rows); a trailing partial batch
    is not evaluated. Each prediction is scored against its reference with
    smoothed 4-gram sentence BLEU, where the "tokens" are single characters.

    Prints ``BLEU Score: <mean>`` and also returns the mean so callers can
    track it. The result is ``nan`` when there is not a single full batch.
    """
    all_bleu = []
    smoothing = nltk.translate.bleu_score.SmoothingFunction().method7

    # Stop before any batch that would run past the end of the validation set.
    for m in range(0, valid_n - BATCH_SIZE + 1, BATCH_SIZE):
        batch_pairs = valid_pairs[m:m + BATCH_SIZE]

        input_batch = np.zeros((BATCH_SIZE, MAX_SEQ_LEN), dtype=np.int32)
        input_lens_batch = np.zeros((BATCH_SIZE,), dtype=np.int32)
        for row, (en_word, _) in enumerate(batch_pairs):
            length, ids = en_lang.encodeSentence2(en_word, MAX_SEQ_LEN)
            input_batch[row, :] = ids
            input_lens_batch[row] = length

        # Inference only needs the encoder inputs; dropout is switched off.
        feed_dict = {
            input_ids: input_batch,
            input_lens: input_lens_batch,
            keep_prob: 1.0
        }
        pred_batch = sess.run(infer_output[0].sample_id, feed_dict=feed_dict)

        for (_, ref), pred_ in zip(batch_pairs, pred_batch):
            pred_s = hi_lang.decodeSentence(list(pred_))
            try:
                _bx = nltk.translate.bleu_score.sentence_bleu(
                    [ref],
                    pred_s,
                    weights=[1/4]*4,
                    smoothing_function=smoothing)
            except ZeroDivisionError:
                # Degenerate prediction/reference pair: count it as a complete miss.
                _bx = 0
            all_bleu.append(_bx)

    # Avoid numpy's "mean of empty slice" warning when nothing was evaluated.
    score = np.mean(all_bleu) if all_bleu else float("nan")
    print(f"BLEU Score: {score}")
    return score
def _encode_batch(lang, words):
    """Encode `words` with `lang` into a (BATCH_SIZE, MAX_SEQ_LEN) id matrix and a length vector."""
    ids_batch = np.zeros((BATCH_SIZE, MAX_SEQ_LEN), dtype=np.int32)
    lens_batch = np.zeros((BATCH_SIZE,), dtype=np.int32)
    for row, word in enumerate(words):
        lens_batch[row], ids_batch[row, :] = lang.encodeSentence2(word, MAX_SEQ_LEN)
    return ids_batch, lens_batch


# Train for 20 epochs, running the validation BLEU check every 100th batch.
for _e in range(20):
    # Mix things up a bit.
    random.shuffle(train_pairs)
    pbar = tqdm_notebook(range(0, train_n, BATCH_SIZE))
    batch_loss = 0
    bxi = 0
    for m in pbar:
        n = m + BATCH_SIZE
        if n > train_n:
            # Trailing partial batch is skipped: every batch is built with
            # exactly BATCH_SIZE rows.
            continue

        batch_pairs = train_pairs[m:n]
        input_batch, input_lens_batch = _encode_batch(
            en_lang, [src for src, _ in batch_pairs])
        target_batch, target_lens_batch = _encode_batch(
            hi_lang, [tgt for _, tgt in batch_pairs])

        feed_dict = {
            input_ids: input_batch,
            input_lens: input_lens_batch,
            ph_target_ids: target_batch,
            target_lens: target_lens_batch,
            keep_prob: 0.8
        }
        # Fetch the loss in the same run as the update: one forward pass
        # instead of two, and the reported loss is the one that was optimised.
        _, loss = sess.run([train_op, cost], feed_dict=feed_dict)
        batch_loss += loss
        pbar.set_description(f"Epoch: {_e} >> Loss: {batch_loss/(bxi+1):2.2F}:")
        bxi += 1
        if (1 + n//BATCH_SIZE) % 100 == 0:
            small_test()
HBox(children=(IntProgress(value=0, max=458), HTML(value='')))


BLEU Score: 0.610158884519212
BLEU Score: 0.6146452675497532
BLEU Score: 0.6058223461268137
BLEU Score: 0.6131855443159836




HBox(children=(IntProgress(value=0, max=458), HTML(value='')))


BLEU Score: 0.6182856361594138
BLEU Score: 0.620302774263697
BLEU Score: 0.6255745267126785
BLEU Score: 0.6288471674183466

Let’s see some real transliteration examples now!

def transliterate(s):
    """Transliterate one English-script word into Devanagari with the trained model.

    The word is encoded into row 0 of an otherwise all-zero batch of
    BATCH_SIZE rows, decoded greedily with dropout disabled, and the
    prediction for row 0 is returned.

    Args:
        s: Word in Latin script (encoded to at most MAX_SEQ_LEN ids).

    Returns:
        The predicted Devanagari string.
    """
    length, ids = en_lang.encodeSentence2(s, MAX_SEQ_LEN)

    # Only the first row carries real data; the rest stay as zeros.
    input_batch = np.zeros((BATCH_SIZE, MAX_SEQ_LEN), dtype=np.int32)
    input_batch[0, :] = ids
    input_lens_batch = np.zeros((BATCH_SIZE,), dtype=np.int32)
    input_lens_batch[0] = length

    predictions = sess.run(
        infer_output[0].sample_id,
        feed_dict={
            input_ids: input_batch,
            input_lens: input_lens_batch,
            keep_prob: 1.0,
        })
    return hi_lang.decodeSentence(list(predictions[0]))
transliterate("saya")
'साया'
transliterate("raama")
'रामा'
transliterate("shahar")
'शाहर'