class HiDDeN(BaseCodec):
    """Image watermarking codec wrapping a HiDDeN encoder/decoder pair.

    ``encode`` embeds a message tensor into a carrier image and returns the
    watermarked image; ``decode`` recovers hard 0/1 message bits from an
    image. Images are converted to RGB, center-cropped to
    ``(config.H, config.W)`` and mapped from ``[0, 1]`` to ``[-1, 1]`` before
    being fed to the networks.
    """

    def __init__(
        self,
        ckpt: Optional[str] = "ckpt/hidden/model.pt",
        configuration: HiDDenConfig = HIDDEN_DEFAULT_CFG,
        cuda: bool = True,
    ) -> None:
        """Build the networks and optionally load pretrained weights.

        Args:
            ckpt: Path to a checkpoint holding ``"encoder"`` and ``"decoder"``
                state dicts, or ``None`` to keep freshly initialised weights.
            configuration: Network configuration; its ``H`` and ``W`` also
                define the crop size applied to input images.
            cuda: Use CUDA when it is available; otherwise run on the CPU.
        """
        super().__init__()
        self.config = configuration
        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu"
        )
        self.encoder = Encoder(configuration).to(self.device)
        self.decoder = Decoder(configuration).to(self.device)
        if ckpt is not None:
            # NOTE(review): torch.load unpickles the file, so only load
            # trusted checkpoints (consider weights_only=True if the torch
            # version in use supports it).
            checkpoint = torch.load(ckpt, map_location=self.device)
            self.encoder.load_state_dict(checkpoint["encoder"])
            self.decoder.load_state_dict(checkpoint["decoder"])
        # This wrapper only runs inference, so switch both networks to eval
        # mode. Otherwise any BatchNorm/Dropout layers would behave as in
        # training: per-image batch statistics, and running stats mutated on
        # every call.
        self.encoder.eval()
        self.decoder.eval()
        self.transform = transforms.Compose(
            [
                transforms.CenterCrop((self.config.H, self.config.W)),
                transforms.ToTensor(),
                # [0, 1] -> [-1, 1]
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        self.inv_transform = transforms.Compose(
            [
                # [-1, 1] -> [0, 1], i.e. (x + 1) / 2
                transforms.Normalize([-1, -1, -1], [2, 2, 2]),
                transforms.ToPILImage(),
            ]
        )

    def _transform(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a normalised ``(1, 3, H, W)`` tensor."""
        # Force three channels: grayscale, palette or RGBA inputs would
        # otherwise yield a channel count the 3-element Normalize rejects.
        return self.transform(image.convert("RGB")).unsqueeze(0)  # type: ignore

    def _inv_transform(self, image: torch.Tensor) -> Image.Image:
        """Convert a batch-of-one tensor in ``[-1, 1]`` back to a PIL image."""
        # Nothing here bounds the encoder output to [-1, 1]. ToPILImage
        # converts floats via ``mul(255).byte()``, which does not saturate,
        # so out-of-range pixels would wrap around; clamp before converting.
        return self.inv_transform(image.squeeze(0).clamp(-1.0, 1.0))  # type: ignore

    def encode(self, carrier: Image.Image, payload: torch.Tensor) -> Image.Image:
        """Embed ``payload`` into ``carrier`` and return the watermarked image.

        Args:
            carrier: Cover image. It is converted to RGB and center-cropped to
                ``(config.H, config.W)``, so the result has that size.
            payload: Message bits. Presumably shaped ``(1, L)`` to match the
                batch of one built from ``carrier``; TODO confirm ``L``
                against the config. Tensors and array-likes are accepted;
                the payload is moved to the model device and cast to the
                image's dtype.

        Returns:
            The watermarked image as a PIL image.
        """
        image = self._transform(carrier).to(self.device)
        # The payload may arrive on the CPU and/or as an integer/bool tensor;
        # align it with the image so it works with a GPU model.
        message = torch.as_tensor(payload, dtype=image.dtype, device=self.device)
        with torch.no_grad():
            encoded = self.encoder(image, message)
        return self._inv_transform(encoded.detach().cpu())

    def decode(self, carrier: Image.Image) -> np.ndarray:
        """Recover the hard message bits embedded in ``carrier``.

        Args:
            carrier: Possibly watermarked image. It gets the same RGB
                conversion and center crop as in ``encode``.

        Returns:
            A ``uint8`` NumPy array of 0/1 values. The input is a batch of
            one, so presumably shaped ``(1, L)``; TODO confirm.
        """
        image = self._transform(carrier).to(self.device)
        with torch.no_grad():
            decoded = self.decoder(image)
        # Hard decision: round the soft outputs and clip into {0, 1}.
        return decoded.detach().cpu().numpy().round().clip(0, 1).astype(np.uint8)