from abc import ABC, abstractmethod
import inspect
import pickle
import numpy as np
from joblib import Parallel, delayed, cpu_count
from .channels import Channel
from .messages import unpack_to_bits, pack_to_dec, generate_data
from .encoders import PolarEncoder, Encoder, PolarWiretapEncoder
from .modulators import Modulator
def _logdomain_sum(x, y):
if x < y:
z = y + np.log1p(np.exp(x-y))
else:
z = x + np.log1p(np.exp(y-x))
return z
def _logdomain_sum_multiple(x, y):
_logpart = np.log1p(np.exp(-np.abs(x-y)))
z = np.maximum(x, y) + _logpart
return z
[docs]class Decoder(ABC):
"""Abstract decoder class."""
def __init__(self, code_length, info_length, base=2, parallel=True):
self.code_length = code_length
self.info_length = info_length
self.base = base
self.parallel = parallel
[docs] @abstractmethod
def decode_messages(self, messages, channel=None): pass
[docs]class IdentityDecoder(Decoder):
"""Identity decoder. Simply returns the input."""
[docs] @staticmethod
def decode_messages(messages, channel=None):
return messages
[docs]class RepetitionDecoder(Decoder):
def __init__(self, *args, **kwargs): pass
[docs] @staticmethod
def decode_messages(messages, channel=None):
decoded = np.zeros((len(messages), 1))
for idx, message in enumerate(messages):
val, counts = np.unique(message, return_counts=True)
_decision = np.argmax(counts)
decoded[idx] = val[_decision]
return decoded
[docs]class LinearDecoder(Decoder):
"""Linear block decoder.
Parameters
----------
TODO
"""
[docs] def decode_messages(self, messages, channel=None):
raise NotImplementedError()
[docs]class PolarDecoder(Decoder):
"""Polar code decoder. Taken from **polarcodes.com**
The decoder for BAWGN channels expects a channel output of noisy codewords
which are modulated to +1 and -1.
Parameters
----------
code_length : int
Length of the code.
info_length : int
Length of the messages.
design_channel : str or Channel
Name of the used channel. Valid choices are currently "BAWGN" and "BSC".
design_channelstate : float, optional
State of the design channel. For "BAWGN" channels, this corresponds to
the SNR value in dB. For "BSC" channels, this corresponds to the
bit-flip probability.
pos_lookup : array, optional
Position lookup of the polar code, where -1 indicates message bits,
while 0 and 1 denote the frozenbits.
frozenbits : array, optional
Bits used for the frozen bit positions. This is ignored, if `pos_lookup`
is provided.
parallel : bool, optional
If True, parallel processing is used. This might not be available on
all machines and causes higher use of system resources.
"""
def __init__(self, code_length, info_length, design_channel,
design_channelstate=0., pos_lookup=None, frozenbits=None,
parallel=True, **kwargs):
if isinstance(design_channel, Channel):
channel_name = design_channel.name
design_channelstate = design_channel.get_channelstate()
else:
channel_name = design_channel
self.design_channel = channel_name
self.design_channelstate = design_channelstate
if pos_lookup is None:
self.pos_lookup = PolarEncoder.construct_polar_code(
code_length, info_length, design_channel, design_channelstate,
frozenbits)
else:
self.pos_lookup = np.array(pos_lookup)
self.rev_index = self._reverse_index(code_length)
self.idx_first_one = self._index_first_num_from_msb(code_length, 1)
self.idx_first_zero = self._index_first_num_from_msb(code_length, 0)
super().__init__(code_length, info_length, parallel=parallel)
@staticmethod
def _reverse_index(code_length):
_n = int(np.ceil(np.log2(code_length)))
rev_idx = [pack_to_dec(np.flip(unpack_to_bits([idx], _n), axis=1))[0][0]
for idx in range(code_length)]
return rev_idx
@staticmethod
def _index_first_num_from_msb(code_length, number):
_n = int(np.ceil(np.log2(code_length)))
idx_list = np.zeros(code_length)
for idx in range(code_length):
idx_bin = unpack_to_bits([idx], _n)[0]
try:
last_level = np.where(idx_bin == number)[0][0]
except IndexError:
last_level = _n-1
idx_list[idx] = last_level
return idx_list
[docs] def decode_messages(self, messages, channel=None):
"""Decode polar encoded messages.
Parameters
----------
messages : array
Array of received (noisy) codewords which were created by polar
encoding messages. Each row represents one received word.
channel : float or Channel, optional
This can either be a channel state, e.g., SNR in an AWGN channel,
of the channel model used for constructing the decoder or a
`channels.Channel` object.
If None, the design parameters are used.
Returns
-------
decoded_messages : array
Array containing the estimated messages after decoding the channel
output.
"""
#decoded = np.zeros((len(messages), self.info_length))
decoded = np.zeros((len(messages), self.code_length))
channel_name = self.design_channel
if channel is None:
channel_state = self.design_channelstate
elif isinstance(channel, Channel):
channel_name = channel.name
if channel_name != self.design_channel:
Warning("The channel you passed for decoding ('{}') is different "
"to the one you used for constructing the decoder ('{}')!"
.format(channel_name, self.design_channel))
channel_state = channel.get_channelstate()
else:
channel_state = channel
if channel_name == "BAWGN":
snr = 10**(channel_state/10.)
initial_llr = -2*np.sqrt(2*(self.info_length/self.code_length)*snr)*messages
#if self.parallel:
# num_cores = cpu_count()
# decoded = Parallel(n_jobs=num_cores)(
# delayed(self._polar_llr_decode)(k) for k in initial_llr)
# decoded = np.array(decoded)
#else:
# for idx, _llr_codeword in enumerate(initial_llr):
# decoded[idx] = self._polar_llr_decode(_llr_codeword)
decoded = self._polar_llr_decode_multiple(initial_llr)
elif channel_name == "BSC":
llr = np.log(channel_state) - np.log(1-channel_state)
initial_llr = (2*messages - 1) * llr
if self.parallel:
num_cores = cpu_count()
decoded = Parallel(n_jobs=num_cores)(
delayed(self._polar_llr_decode)(k) for k in initial_llr)
decoded = np.array(decoded)
else:
for idx, _llr_codeword in enumerate(initial_llr):
decoded[idx] = self._polar_llr_decode(_llr_codeword)
decoded = self._get_info_bit_positions(decoded)
return decoded
def _get_info_bit_positions(self, decoded):
return decoded[:, self.pos_lookup == -1]
def _polar_llr_decode(self, initial_llr):
llr = np.zeros(2*self.code_length-1)
llr[self.code_length-1:] = initial_llr
bit_branch = np.zeros((2, self.code_length-1))
decoded = np.zeros(self.code_length)
for j in range(self.code_length):
rev_idx = self.rev_index[j]
llr = self._update_llr(llr, bit_branch, rev_idx)
if self.pos_lookup[rev_idx] <= -1:
if llr[0] > 0:
decoded[rev_idx] = 0
else:
decoded[rev_idx] = 1
else:
decoded[rev_idx] = self.pos_lookup[rev_idx]
bit_branch = self._update_bit_branch(decoded[rev_idx], rev_idx, bit_branch)
#return decoded[self.pos_lookup == -1]
return decoded
def _update_llr(self, llr, bit_branch, rev_idx):
_n = int(np.ceil(np.log2(self.code_length)))
if rev_idx == 0:
next_level = _n
else:
last_level = int(self.idx_first_one[rev_idx]+1)
st = int(2**(last_level-1))
ed = int(2**(last_level)-1)
for idx in range(st-1, ed):
llr[idx] = self._lowerconv(
bit_branch[0, idx], llr[ed+2*(idx+1-st)], llr[ed+2*(idx+1-st)+1])
next_level = last_level - 1
for level in np.arange(next_level, 0, -1):
st = int(2**(level-1))
ed = int(2**(level) - 1)
for idx in range(st-1, ed):
llr[idx] = self._upperconv(llr[ed+2*(idx+1-st)], llr[ed+2*(idx+1-st)+1])
return llr
@staticmethod
def _lowerconv(upper_decision, upper_llr, lower_llr):
if upper_decision == 0:
llr = lower_llr + upper_llr
else:
llr = lower_llr - upper_llr
return llr
@staticmethod
def _upperconv(llr1, llr2):
llr = _logdomain_sum(llr1+llr2, 0) - _logdomain_sum(llr1, llr2)
return llr
def _update_bit_branch(self, bit, rev_idx, bit_branch):
_n = int(np.ceil(np.log2(self.code_length)))
if rev_idx == self.code_length-1:
return
elif rev_idx < self.code_length/2:
bit_branch[0, 0] = bit
else:
last_level = int(self.idx_first_zero[rev_idx]+1)
bit_branch[1, 0] = bit
for level in range(1, last_level-2+1):
st = int(2**(level-1))
ed = int(2**(level)-1)
for idx in range(st-1, ed):
bit_branch[1, ed+2*(idx+1-st)] = np.mod(bit_branch[0, idx]+bit_branch[1, idx], 2)
bit_branch[1, ed+2*(idx+1-st)+1] = bit_branch[1, idx]
level = last_level-1
st = int(2**(level-1))
ed = int(2**(level)-1)
for idx in range(st-1, ed):
bit_branch[0, ed+2*(idx+1-st)] = np.mod(bit_branch[0, idx]+bit_branch[1, idx], 2)
bit_branch[0, ed+2*(idx+1-st)+1] = bit_branch[1, idx]
return bit_branch
#####
def _polar_llr_decode_multiple(self, initial_llr):
llr = np.zeros((len(initial_llr), 2*self.code_length-1))
llr[:, self.code_length-1:] = initial_llr
bit_branch = np.zeros((len(initial_llr), 2, self.code_length-1))
decoded = np.zeros((len(initial_llr), self.code_length))
for j in range(self.code_length):
rev_idx = self.rev_index[j]
llr = self._update_llr_multiple(llr, bit_branch, rev_idx)
if self.pos_lookup[rev_idx] <= -1:
#decoded[:, rev_idx] = 0
_idx = np.where(llr[:, 0] <= 0)[0]
decoded[_idx, rev_idx] = 1
else:
decoded[:, rev_idx] = self.pos_lookup[rev_idx]
bit_branch = self._update_bit_branch_multiple(
decoded[:, rev_idx], rev_idx, bit_branch)
#return decoded[self.pos_lookup == -1]
return decoded
def _update_llr_multiple(self, llr, bit_branch, rev_idx):
_n = int(np.ceil(np.log2(self.code_length)))
if rev_idx == 0:
next_level = _n
else:
last_level = int(self.idx_first_one[rev_idx]+1)
st = int(2**(last_level-1))
ed = int(2**(last_level)-1)
for idx in range(st-1, ed):
llr[:, idx] = self._lowerconv_multiple(
bit_branch[:, 0, idx], llr[:, ed+2*(idx+1-st)], llr[:, ed+2*(idx+1-st)+1])
next_level = last_level - 1
for level in np.arange(next_level, 0, -1):
st = int(2**(level-1))
ed = int(2**(level) - 1)
for idx in range(st-1, ed):
llr[:, idx] = self._upperconv_multiple(
llr[:, ed+2*(idx+1-st)], llr[:, ed+2*(idx+1-st)+1])
return llr
def _update_bit_branch_multiple(self, bit, rev_idx, bit_branch):
_n = int(np.ceil(np.log2(self.code_length)))
if rev_idx == self.code_length-1:
return
elif rev_idx < self.code_length/2:
bit_branch[:, 0, 0] = bit
else:
last_level = int(self.idx_first_zero[rev_idx]+1)
bit_branch[:, 1, 0] = bit
for level in range(1, last_level-2+1):
st = int(2**(level-1))
ed = int(2**(level)-1)
for idx in range(st-1, ed):
bit_branch[:, 1, ed+2*(idx+1-st)] = np.mod(bit_branch[:, 0, idx]+bit_branch[:, 1, idx], 2)
bit_branch[:, 1, ed+2*(idx+1-st)+1] = bit_branch[:, 1, idx]
level = last_level-1
st = int(2**(level-1))
ed = int(2**(level)-1)
for idx in range(st-1, ed):
bit_branch[:, 0, ed+2*(idx+1-st)] = np.mod(bit_branch[:, 0, idx]+bit_branch[:, 1, idx], 2)
bit_branch[:, 0, ed+2*(idx+1-st)+1] = bit_branch[:, 1, idx]
return bit_branch
@staticmethod
def _lowerconv_multiple(upper_decision, upper_llr, lower_llr):
llr = lower_llr - upper_llr
idx = np.where(upper_decision == 0)
llr[idx] = lower_llr[idx] + upper_llr[idx]
return llr
@staticmethod
def _upperconv_multiple(llr1, llr2):
llr = _logdomain_sum_multiple(llr1+llr2, 0) - _logdomain_sum_multiple(llr1, llr2)
return llr
####
[docs]class PolarWiretapDecoder(PolarDecoder):
"""Decoder class for decoding polar wiretap codes.
You can either provide both channels (to Bob and Eve) or provide the main
channel to Bob and the position lookup of the already constructed code.
Parameters
----------
code_length : int
Length of the codewords.
design_channel_bob : str
Channel name of the main channel to Bob. Valid choices are the channel
models which are supported by the PolarDecoder.
design_channel_eve : str, optional
Channel name of the side channel to Eve. Valid choices are the channel
models which are supported by the PolarEncoder.
design_channelstate_bob : float, optional
Channelstate of the main channel.
design_channelstate_eve : float, optional
Channelstate of the side channel.
pos_lookup : array, optional
Position lookup of the constructed wiretap code. If this is provided,
no additional code is constructed and the values of Eve's channel are
ignored.
"""
def __init__(self, code_length, design_channel_bob, design_channel_eve=None,
design_channelstate_bob=0, design_channelstate_eve=0.,
pos_lookup=None, frozenbits=None, parallel=True,
info_length_bob=None, random_length=None, **kwargs):
if pos_lookup is None:
pos_lookup = PolarWiretapEncoder.construct_polar_wiretap_code(
code_length, design_channel_bob, design_channel_eve,
design_channelstate_bob, design_channelstate_eve, frozenbits,
info_length_bob, random_length)
info_length = np.count_nonzero(pos_lookup == -1)
info_length_bob = np.count_nonzero(pos_lookup < 0)
super().__init__(code_length, info_length, design_channel_bob,
design_channelstate=design_channelstate_bob,
pos_lookup=pos_lookup, frozenbits=frozenbits,
parallel=parallel, **kwargs)