# wavenet

Still need to expand my thoughts here.

import torch
import matplotlib.pyplot as plt
%matplotlib inline

torch.manual_seed(101)

words = open('names.txt', 'r').read().splitlines()
[training_set, validation_set, test_set] = torch.utils.data.random_split(words, [.8 ,.1 ,.1])

TOKEN='.'
chars = list(set(''.join(words)+TOKEN))
chars.sort()
itos = dict([(i,s) for i, s in enumerate(chars)])
stoi = dict([s, i] for i, s in enumerate(chars))
vocab_size = len(chars)
block_size = 8

def build_dataset(words):
    xs = []
    ys = []
    for word in words:
        context = [0]*block_size
        for i, ch in enumerate(word+TOKEN):
            ix = stoi[ch]
            xs.append(context)
            ys.append(ix)
            context = context[1:]+[ix]
    return torch.tensor(xs), torch.tensor(ys)

Xtr, Ytr = build_dataset(training_set)
Xval, Yval = build_dataset(validation_set)
Xtest, Ytest = build_dataset(test_set)

n_embd = 24
n_hidden = 128

# Departs from pytorch implementation https://youtu.be/t3YJ5hKiMQ0?si=dIVp0YFYBFH0NpdS&t=2620
class BatchNorm1d(torch.nn.Module):
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    super(BatchNorm1d, self).__init__()
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)

  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      if x.ndim == 2:
        dim = 0
      elif x.ndim == 3:
        dim = (0,1)
      xmean = x.mean(dim, keepdim=True) # batch mean
      xvar = x.var(dim, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]


class FlattenConsecutive(torch.nn.Module):
    def __init__(self, n):
      super(FlattenConsecutive, self).__init__()
      self.n = n

    def __call__(self, x):
        num_batches, num_characters, num_embeddings = x.shape
        x = x.view(num_batches, num_characters//self.n, num_embeddings*self.n)
        if x.shape[1] == 1:
          x = x.squeeze(1)
        self.out = x
        return self.out

model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, n_embd),
    FlattenConsecutive(2), torch.nn.Linear(n_embd * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), torch.nn.Tanh(),
    FlattenConsecutive(2), torch.nn.Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), torch.nn.Tanh(),
    FlattenConsecutive(2), torch.nn.Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), torch.nn.Tanh(),
    torch.nn.Linear(n_hidden, vocab_size)
)

with torch.no_grad():
    model[-1].weight.data *= .1

print(sum(p.nelement() for p in model.parameters()))
# 75811

for p in model.parameters():
  p.requires_grad = True

max_steps = 25000
batch_size = 32
for i in range(max_steps):
    ix = torch.randint(0, Xtr.shape[0], (batch_size, ))
    Xb, Yb = Xtr[ix], Ytr[ix]
    # forward pass
    x = model(Xb)
    loss = torch.nn.functional.cross_entropy(x, Yb)

    for p in model.parameters():
        p.grad = None
    loss.backward()

    lr = 0.1 if i<(max_steps*.75) else .01
    for p in model.parameters():
        p.data += -lr * p.grad

    if i%(max_steps/10) == 0:
        print(f"{i:7d}/{max_steps:7d}: loss={loss.item(): .4f}")

#       0/  25000: loss= 3.2859
#    2500/  25000: loss= 2.5645
#    5000/  25000: loss= 2.4274
#    7500/  25000: loss= 2.3042
#   10000/  25000: loss= 2.1279
#   12500/  25000: loss= 2.0867
#   15000/  25000: loss= 1.9211
#   17500/  25000: loss= 2.0363
#   20000/  25000: loss= 2.0193
#   22500/  25000: loss= 1.9068

model.eval()

# evaluate the loss
@torch.no_grad() # this decorator disables gradient tracking inside pytorch
def split_loss(split):
  x,y = {
    'train': (Xtr, Ytr),
    'val': (Xval, Yval),
    'test': (Xtest, Ytest),
  }[split]
  logits = model(x)
  loss = torch.nn.functional.cross_entropy(logits, y)
  print(split, loss.item())

split_loss('train')
split_loss('val')
split_loss('test')

# train 1.9855728149414062
# val 2.0416152477264404
# test 2.034022331237793

num_examples = 10
for _ in range(num_examples):
    CONTEXT = [0] * block_size
    out = []
    while True:
        x = model(torch.tensor([CONTEXT]))
        probs = torch.softmax(x, 1)
        ix = torch.multinomial(probs, num_samples=1, replacement=True).item()
        if ix == 0:
            break
        CONTEXT = CONTEXT[1:] + [ix]
        out.append(ix)
    print(''.join([itos[c] for c in out]))

# daorbose
# elinnah
# xaras
# bryston
# ambiya
# raliya
# zahaa
# eyin
# laui
# riyvanna