Skip to content

HiDDeN

stegobox.codec.HiDDeN

Bases: BaseCodec

Source code in stegobox/codec/hidden/model.py
class HiDDeN(BaseCodec):
    """Codec that embeds a message tensor into an image and reads it back.

    The codec wraps an ``Encoder``/``Decoder`` pair built from a
    ``HiDDenConfig``. Images are centre-cropped to ``(config.H, config.W)``,
    converted to tensors and normalised to ``[-1, 1]`` before being passed to
    the networks; encoder output is mapped back to ``[0, 1]`` and converted
    to a PIL image.

    Args:
        ckpt: Path to a checkpoint holding ``"encoder"`` and ``"decoder"``
            state dicts. Pass ``None`` to keep the freshly initialised
            weights.
        configuration: Configuration used to build both networks and to
            determine the crop size (``H`` and ``W``).
        cuda: Use CUDA when it is requested *and* available; otherwise the
            codec runs on the CPU.
    """

    def __init__(
        self,
        ckpt: Optional[str] = "ckpt/hidden/model.pt",
        configuration: HiDDenConfig = HIDDEN_DEFAULT_CFG,
        cuda: bool = True,
    ) -> None:
        super().__init__()
        self.config = configuration

        device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
        self.device = device

        self.encoder = Encoder(configuration).to(device)
        self.decoder = Decoder(configuration).to(device)

        if ckpt is not None:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(ckpt, map_location=device)
            self.encoder.load_state_dict(checkpoint["encoder"])
            self.decoder.load_state_dict(checkpoint["decoder"])

        # This class only performs inference, so switch both networks to
        # evaluation mode. Without this, any BatchNorm/Dropout layers inside
        # them would use training behaviour (and BatchNorm would update its
        # running statistics on every call).
        self.encoder.eval()
        self.decoder.eval()

        self.transform = transforms.Compose(
            [
                transforms.CenterCrop((self.config.H, self.config.W)),
                transforms.ToTensor(),
                # Map [0, 1] -> [-1, 1] per channel.
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        self.inv_transform = transforms.Compose(
            [
                # (x - (-1)) / 2 maps [-1, 1] back to [0, 1].
                transforms.Normalize([-1, -1, -1], [2, 2, 2]),
                transforms.ToPILImage(),
            ]
        )

    def _transform(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a normalised tensor with a batch axis.

        The image is converted to RGB first because the normalisation step
        is defined for exactly three channels.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0)  # type: ignore

    def _inv_transform(self, image: torch.Tensor) -> Image.Image:
        """Convert a batched network output back into a PIL image.

        Values are clamped to ``[-1, 1]`` before de-normalisation: the
        float-to-uint8 conversion in ``ToPILImage`` does not saturate, so
        out-of-range values would otherwise wrap around and show up as
        speckle artifacts.
        """
        clamped = image.squeeze(0).clamp(-1.0, 1.0)
        return self.inv_transform(clamped)  # type: ignore

    def encode(self, carrier: Image.Image, payload: torch.Tensor) -> Image.Image:
        """Embed ``payload`` into ``carrier`` and return the encoded image.

        Args:
            carrier: Cover image; it is centre-cropped to the configured
                size before encoding.
            payload: Message tensor passed to the encoder together with the
                image. It is moved to the codec's device here. The shape and
                dtype the encoder expects are not visible in this file --
                TODO confirm against ``Encoder``.

        Returns:
            The encoded image as a PIL image.
        """
        image = self._transform(carrier).to(self.device)
        # The payload must live on the same device as the image and the
        # encoder weights, otherwise the forward pass fails on CUDA.
        payload = payload.to(self.device)

        with torch.no_grad():
            encoded = self.encoder(image, payload)
        return self._inv_transform(encoded.detach().cpu())

    def decode(self, carrier: Image.Image) -> np.ndarray:
        """Recover the embedded message from ``carrier``.

        Args:
            carrier: Image to decode; it goes through the same preprocessing
                as in ``encode``.

        Returns:
            The decoder output rounded, clipped to ``{0, 1}`` and cast to
            ``uint8``.
        """
        image = self._transform(carrier).to(self.device)
        with torch.no_grad():
            decoded = self.decoder(image)

        return decoded.detach().cpu().numpy().round().clip(0, 1).astype(np.uint8)