Optimize an MZI in the time domain to perform an XOR on two consecutive bits

Imports

[1]:
import torch
import numpy as np
import photontorch as pt
import matplotlib.pyplot as plt

from ipywidgets import interact
from tqdm.notebook import trange

Parameters & Settings

[2]:
# --- physical constants and signal settings ---
c = 299792458.0 #[m/s] speed of light
Rb = 20e9 #[1/s] bitrate
Rs = 320e9 #[1/s] samplerate (Rs/Rb = 16 samples per bit)
fc = 20e9 #[1/s] cutoff frequency for bit generation
wl0 = 1550e-9 #[m] center wavelength
neff = 2.34 #[1] effective index
ng = 3.4 #[1] group index
# --- dataset sizes and seeds ---
N = 100 #[1] number of bits per bit stream
B = 10 #[1] number of bit streams in training batch
Str, Sva, Ste = (1, 2, 3) #[1] train seed, validation seed, test seed
# --- derived quantities ---
Lr = (1/Rb) * (c/ng) #[m] reference length (distance the signal travels during one bit period, about 4.41 mm)
# --- runtime configuration ---
device = torch.device("cpu") # default device used
torch.set_default_dtype(torch.float32) # default dtype used
torch.manual_seed(123) # pytorch seed
np.random.seed(42) # numpy seed (global numpy RNG; Circuit draws its initial couplings and phases from it)

Bit generator

[3]:
class BitGenerator(np.random.RandomState):
    """Seedable random source that produces batches of boolean bit sequences."""

    def __call__(self, N, B=1):
        """ draw a batch of random bits

        Args:
            N: number of bits per sequence
            B: number of sequences in batch

        Returns:
            boolean array of shape (N, B); an entry is True where the
            corresponding uniform draw from ``self.rand`` is larger than 0.5
        """
        uniform_draws = self.rand(N, B)
        return uniform_draws > 0.5

Bit operations

[4]:
def COPY(bits):
    """Identity bit operation: hand back an independent copy of ``bits``."""
    duplicate = bits.copy()
    return duplicate
def XOR(bits):
    """XOR every bit with the bit that precedes it along axis 0.

    The first entry is combined with an implicit leading zero, so the result
    has the same shape as ``bits``.
    """
    leading_zero = np.zeros_like(bits[:1])
    # the input shifted down by one position, with the zero filling the gap
    previous = np.concatenate([leading_zero, bits[:-1]], 0)
    return bits ^ previous
def AND(bits):
    """AND every bit with the bit that precedes it along axis 0.

    The first entry is combined with an implicit leading zero, so the result
    has the same shape as ``bits``.
    """
    leading_zero = np.zeros_like(bits[:1])
    # the input shifted down by one position, with the zero filling the gap
    previous = np.concatenate([leading_zero, bits[:-1]], 0)
    return bits & previous

# Bit operation applied to the input bits to build the target streams.
TARGET = XOR

Generate Streams

[5]:
# Short example (10 bits, one sequence) to visualise the input and its target.
bits = BitGenerator(seed=23)(10)
streamgen = pt.BitStreamGenerator(bitrate=Rb, samplerate=Rs, cutoff_frequency=fc, device=device)
stream = streamgen(bits)
target = streamgen(TARGET(bits))
t = np.arange(stream.shape[0])*(1/Rs)  # [s] one time value per sample
# Label both curves so the figure can be read without looking at the code.
plt.plot(1e9*t, stream.detach().cpu().numpy(), label="input")
plt.plot(1e9*t, target.detach().cpu().numpy(), label="target")
plt.xlabel("t [ns]")
plt.legend()
plt.show()
../_images/examples_09_XOR_task_with_MZI_10_0.png
[6]:
def _make_streams(seed, n_bits, n_streams):
    """Generate random bits together with their input and target streams.

    Args:
        seed: seed handed to ``BitGenerator``
        n_bits: number of bits per stream
        n_streams: number of streams in the batch

    Returns:
        (bits, stream, target): the raw bits, the stream ``streamgen`` builds
        from them, and the squared stream built from ``TARGET(bits)``.
    """
    bits = BitGenerator(seed=seed)(n_bits, n_streams)
    return bits, streamgen(bits), streamgen(TARGET(bits))**2

# train streams: N bits per train stream, B train streams per batch.
trbits, trstream, trtarget = _make_streams(Str, N, B)
t = np.arange(trstream.shape[0])*(1/Rs)  # [s] time axis of the train streams

# validation streams: N bits per validation stream, 1 validation stream per batch.
vabits, vastream, vatarget = _make_streams(Sva, N, 1)

# test streams: N*B bits in a single test stream.
tebits, testream, tetarget = _make_streams(Ste, N*B, 1)

Simple MZI Circuit

[7]:
class Circuit(pt.Network):
    """Mach-Zehnder interferometer with trainable couplers and arms.

    Two directional couplers (``dc1``, ``dc2``) are joined by two waveguide
    arms of length ``2*Lr`` (``wg1``) and ``1*Lr`` (``wg2``), so the arms
    differ by one reference length ``Lr``. Initial couplings and phases are
    drawn from the global numpy RNG.
    """
    def __init__(self):
        super().__init__()
        self.src = pt.Source()
        self.det = pt.Detector()
        # One Term instance per termination: a single object bound to both
        # names would be one shared submodule, which torch de-duplicates when
        # iterating over children.
        self.trm1 = pt.Term()
        self.trm2 = pt.Term()
        self.dc1 = pt.DirectionalCoupler(coupling=np.random.rand(), trainable=True)
        self.dc2 = pt.DirectionalCoupler(coupling=np.random.rand(), trainable=True)
        self.wg1 = pt.Waveguide(length=2*Lr, phase=2*np.pi*np.random.rand(), loss=0.0, trainable=True)
        self.wg2 = pt.Waveguide(length=1*Lr, phase=2*np.pi*np.random.rand(), loss=0.0, trainable=True)
        # chain 1: trm1 -> dc1 -> wg1 (long arm) -> dc2 -> trm2
        self.link("trm1:0", "3:dc1:2", "0:wg1:1", "3:dc2:2", "0:trm2")
        # chain 2: src -> dc1 -> wg2 (short arm) -> dc2 -> det
        self.link( "src:0", "0:dc1:1", "0:wg2:1", "0:dc2:1", "0:det")

Training

Initial plot (to find a good latency)

[8]:
# Response of the freshly initialised (untrained) circuit to the validation stream.
nw = Circuit()
with pt.Environment(t=t):
    # lower dimensional sources should have named dimensions
    named_source = vastream.rename("t", "b")
    detected = nw(source=named_source)[:, 0, 0, :]

# Loss object with zero latency; the latency is overridden per plot call below.
mse = pt.MSELoss(
    latency=0.0,
    warmup=0.0,
    bitrate=Rb,
    samplerate=Rs,
)
@interact(latency=(0,3,0.1))
def _(latency=1.0):
    """Plot input, target and detected streams, shifting the detected stream by ``latency``."""
    for curve, name in ((vastream, "input"), (vatarget, "target")):
        mse.plot(curve, label=name)
    # only the detected signal is shifted by the slider value
    mse.plot(detected, latency=latency, label="detected")
    plt.legend()
    plt.xlim(0, 1)
    plt.show()
[9]:
# Settings shared by both losses (latency of 1.0 as picked from the plot above).
loss_settings = dict(latency=1.0, warmup=0.0, bitrate=Rb, samplerate=Rs)
mse = pt.MSELoss(**loss_settings)
ber = pt.BERLoss(threshold=0.5, **loss_settings)
[10]:
# Trainable affine readout (scale, bias) applied to the detected signal.
scale = torch.tensor(1.0, requires_grad=True)
bias = torch.tensor(0.0, requires_grad=True)

# Optimise the readout together with every circuit parameter.
trainable = [scale, bias] + list(nw.parameters())
optim = torch.optim.Adam(trainable, lr=0.1)

range_ = trange(40)
with pt.Environment(t=t, enable_grad=True):
    for epoch in range_:
        optim.zero_grad()
        raw_det = nw(source=trstream.rename("t", "b"))[:, 0, 0, :]
        trdet = scale * raw_det + bias
        loss = mse(trdet, trtarget)
        loss.backward()
        optim.step()
        # progress bar shows the loss and BER of this step's forward pass
        range_.set_postfix(loss=loss.item(), ber=ber(trdet, trtarget))

Let’s take a look at the final result:

[11]:
with pt.Environment(t=t):
    vadet = nw(source=vastream.rename("t", "b"))[:,0,0,:]
    # Evaluate both metrics on the same scaled signal, exactly as in training.
    # Thresholding the raw signal at (0.5-bias)/scale is only equivalent for
    # scale > 0; a negative learned scale would invert the decision.
    vadet_scaled = scale*vadet + bias
    ber_err = ber(vadet_scaled.detach(), vatarget)
    mse_err = mse(vadet_scaled, vatarget)

    print(f"ber:\t{ber_err}")
    print(f"mse:\t{mse_err.item()}")

    mse.plot(vadet, label="detected")
    mse.plot(vatarget, latency=0.0, label="target")
    plt.xlim(0, 1.15)
    # NOTE(review): the label says ps, but a stream of N bits at Rb lasts
    # N/Rb = 5 ns and the window shown is 0..1.15 - confirm the unit that
    # mse.plot puts on the x axis.
    plt.xlabel("t [ps]")
    plt.legend()
    plt.show()
ber:    0.0
mse:    0.057594358921051025
../_images/examples_09_XOR_task_with_MZI_20_1.png