Skip to content

DeepMIH2PNGinPNGTPAMI2022

stegobox.codec.DeepMIH2PNGinPNGTPAMI2022

Bases: BaseCodec

This steganography method is named: DeepMIH2PNGinPNGTPAMI2022

  • Created by: Ruinan Ma
  • Created time: 2022/10/18

This is a PyTorch implementation of deep-learning-based image steganography, as released in the paper "DeepMIH: Deep Invertible Network for Multiple Image Hiding" (https://ieeexplore.ieee.org/document/9676416).

Source code in stegobox/codec/deepmih_2pnginpng_tpami2022/deepmih_2pnginpng.py
class DeepMIH2PNGinPNGTPAMI2022(BaseCodec):
    """
    This steganography method is named: `DeepMIH2PNGinPNGTPAMI2022`

    * Created by: Ruinan Ma
    * Created time: 2022/10/18

    This is a PyTorch implementation of deep-learning-based image steganography, as
    released in the paper "DeepMIH: Deep Invertible Network for Multiple Image
    Hiding" https://ieeexplore.ieee.org/document/9676416

    Two secret images are hidden one after the other. ``net1`` hides ``payload1``
    in the cover to produce ``steg1``. ``net2`` then hides ``payload2`` in
    ``steg1`` to produce ``steg2``. Decoding runs both invertible networks in
    reverse (``rev=True``).
    """

    def __init__(
        self,
        para_pt_1: str = "ckpt/deepmih_2pnginpng_tpami2022/model_checkpoint_03000_1.pt",
        para_pt_2: str = "ckpt/deepmih_2pnginpng_tpami2022/model_checkpoint_03000_2.pt",
        para_pt_3: str = "ckpt/deepmih_2pnginpng_tpami2022/model_checkpoint_03000_3.pt",
        use_img_map: bool = False,
        verbose: bool = False,
    ) -> None:
        """Create the codec and load the three pretrained sub-networks.

        Args:
            para_pt_1: Checkpoint path for the first hiding network (``net1``).
            para_pt_2: Checkpoint path for the second hiding network (``net2``).
            para_pt_3: Checkpoint path for the ``ImpMapBlock`` network (``net3``).
            use_img_map: If True, ``net3`` computes the map fed to ``net2``
                during encoding; otherwise an all-zero map is used.
            verbose: Print progress messages.
        """
        super().__init__()
        self.verbose = verbose
        self.use_img_map = use_img_map
        self.para_path1 = para_pt_1
        self.para_path2 = para_pt_2
        self.para_path3 = para_pt_3
        self.dwt = DWT()
        self.iwt = IWT()
        self.transform = transforms.Compose(
            [transforms.CenterCrop(c.cropsize_val), transforms.ToTensor()]
        )
        self.net1, self.net2, self.net3 = self.net_init()
        if verbose:
            if torch.cuda.is_available():
                print("Running DeepMIH with GPU.")
            else:
                print("Running DeepMIH with CPU.")

    def net_init(self):
        """Build, load and return ``(net1, net2, net3)`` ready for inference.

        Each network is moved to ``device``. ``net1``/``net2`` are initialised
        with ``init_model``. All three are wrapped in ``DataParallel`` over
        ``c.device_ids``, loaded from their checkpoint paths, and switched to
        eval mode.
        """
        net1 = Model1()
        net2 = Model2()
        net3 = ImpMapBlock()
        net1.to(device)
        net2.to(device)
        net3.to(device)
        init_model(net1)
        init_model(net2)
        net1 = torch.nn.parallel.DataParallel(net1, device_ids=c.device_ids)
        net2 = torch.nn.parallel.DataParallel(net2, device_ids=c.device_ids)
        net3 = torch.nn.parallel.DataParallel(net3, device_ids=c.device_ids)
        load(self.para_path1, net1)
        load(self.para_path2, net2)
        load(self.para_path3, net3)
        net1.eval()
        net2.eval()
        net3.eval()
        if self.verbose:
            print("Pretrained models load successsfully.")
        return net1, net2, net3

    def encode(self, _, __) -> None:
        """Not supported: this codec hides two payloads; use encode_multiple()."""
        raise NotImplementedError("Use encode_multiple() instead.")

    def decode(self, _) -> None:
        """Not supported: this codec recovers two payloads; use decode_multiple()."""
        raise NotImplementedError("Use decode_multiple() instead.")

    def _prepare(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a batched tensor on ``device``.

        ``ToTensor`` keeps the image's band count, so RGBA, palette or grayscale
        PNGs would not give the 3 channels the pipeline below is built for
        (e.g. (1, 3, 256, 256) images -> (1, 12, 128, 128) DWT bands). They are
        therefore converted to RGB first. RGB images are used as-is.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Crop to c.cropsize_val, scale to a tensor, add the batch dimension.
        return self.transform(img).unsqueeze(dim=0).to(device)  # type: ignore

    def encode_multiple(
        self, carrier: Image.Image, payload1: Image.Image, payload2: Image.Image
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Hide two secret PNG images in a PNG cover image.

        Args:
            carrier: Cover image.
            payload1: First secret image, hidden by ``net1``.
            payload2: Second secret image, hidden by ``net2``.

        Returns:
            ``(steg1, steg2)``. ``steg1`` is the intermediate stego image holding
            ``payload1``. ``steg2`` is the final stego image holding both
            payloads. Each is a ``(1, 3, H, W)`` tensor on ``device``.
        """
        if self.verbose:
            print("Encoding...")

        # Pure inference: without no_grad autograd would record the graph and
        # keep every intermediate activation alive for each call.
        with torch.no_grad():
            """prepare"""
            cover = self._prepare(carrier)
            secret1 = self._prepare(payload1)
            secret2 = self._prepare(payload2)
            # (1, 3, 256, 256) -> (1, 12, 128, 128)
            cover_dwt = self.dwt(cover)
            secret1_dwt = self.dwt(secret1)
            secret2_dwt = self.dwt(secret2)

            """forward1"""
            # net1: [cover | secret1] (1, 24, 128, 128) -> (1, 24, 128, 128);
            # the first 4 * channels_in channels are steg1 in the DWT domain.
            out1 = self.net1(torch.cat((cover_dwt, secret1_dwt), dim=1))
            steg1_dwt = out1.narrow(1, 0, 4 * c.channels_in)
            steg1 = self.iwt(steg1_dwt).to(device)

            """forward2"""
            # All net3 inputs now live on the same device.
            if self.use_img_map:
                imp_map = self.net3(cover, secret1, steg1)
            else:
                imp_map = torch.zeros_like(cover)
            imp_map_dwt = self.dwt(imp_map)
            # net2: [steg1 | imp map | secret2] (1, 36, 128, 128) -> 36 channels;
            # the first 4 * channels_in channels are steg2 in the DWT domain.
            net2_in = torch.cat((steg1_dwt, imp_map_dwt, secret2_dwt), dim=1)
            out2 = self.net2(net2_in)
            steg2_dwt = out2.narrow(1, 0, 4 * c.channels_in)
            steg2 = self.iwt(steg2_dwt).to(device)

        return steg1, steg2

    def decode_multiple(
        self, container1: Image.Image, container2: Image.Image
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Recover both secret images from the stego images.

        Both secrets are recovered from ``container2``:

        * ``net2`` is inverted to get the intermediate stego image and
          ``payload2``.
        * ``net1`` is then inverted on that recovered stego image to get
          ``payload1``.

        The pixel content of ``container1`` is not used. Only its shape (after
        cropping) sets the shape of the ``gauss_noise`` input for ``net1``.

        Args:
            container1: First stego image (``steg1`` from ``encode_multiple``).
            container2: Second stego image (``steg2`` from ``encode_multiple``).

        Returns:
            ``(secret2, secret1)``. Note the order: the image hidden as
            ``payload2`` comes first. Each is a ``(1, 3, H, W)`` tensor on
            ``device``.
        """
        if self.verbose:
            print("Decoding...")

        # Pure inference; see encode_multiple.
        with torch.no_grad():
            """prepare"""
            # (1, 12, 128, 128) DWT bands of each stego image.
            stego1_dwt = self.dwt(self._prepare(container1))
            stego2_dwt = self.dwt(self._prepare(container2))
            # net2's auxiliary branch spans both 12-channel slots
            # (1, 24, 128, 128). Build the shape directly instead of
            # concatenating real tensors just to read .shape.
            n, ch1, h, w = stego1_dwt.shape
            ch2 = stego2_dwt.shape[1]
            # Same call order as before, so the RNG stream is unchanged.
            noise1 = gauss_noise(stego1_dwt.shape)
            noise2 = gauss_noise(torch.Size((n, ch1 + ch2, h, w)))

            """backward2"""
            # Invert net2 on [steg2 | noise2] (1, 36, 128, 128); the result holds
            # the recovered steg1 and secret2, 4 * channels_in channels each.
            rev2 = self.net2(torch.cat((stego2_dwt, noise2), dim=1), rev=True)
            rev_steg1_dwt = rev2.narrow(1, 0, 4 * c.channels_in)
            secret2_dwt = rev2.narrow(1, 4 * c.channels_in, 4 * c.channels_in)
            # The intermediate container is available as self.iwt(rev_steg1_dwt).
            secret2 = self.iwt(secret2_dwt).to(device)

            """backward1"""
            # Invert net1 on [recovered steg1 | noise1] (1, 24, 128, 128).
            rev1 = self.net1(torch.cat((rev_steg1_dwt, noise1), dim=1), rev=True)
            secret1_dwt = rev1.narrow(1, 4 * c.channels_in, 4 * c.channels_in)
            secret1 = self.iwt(secret1_dwt).to(device)

        return secret2, secret1

encode_multiple(carrier, payload1, payload2)

Hide two secret PNG images in a PNG cover image.

Parameters:

Name Type Description Default
carrier Image

cover image

required
payload1 Image

Payload secret image1

required
payload2 Image

Payload secret image2

required

Returns:

Type Description
tuple[Tensor, Tensor]

The two encoded steganographic images (steg1, steg2) as torch.Tensor.

Source code in stegobox/codec/deepmih_2pnginpng_tpami2022/deepmih_2pnginpng.py
def encode_multiple(
    self, carrier: Image.Image, payload1: Image.Image, payload2: Image.Image
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hide two secret PNG images in a PNG cover image.

    Args:
        carrier: Cover image.
        payload1: First secret image, hidden by ``net1``.
        payload2: Second secret image, hidden by ``net2``.

    Returns:
        ``(steg1, steg2)``. ``steg1`` is the intermediate stego image holding
        ``payload1``. ``steg2`` is the final stego image holding both payloads.
        Each is a ``(1, 3, H, W)`` tensor on ``device``.
    """

    def to_batch(img: Image.Image) -> torch.Tensor:
        # ToTensor keeps the image's band count, so RGBA, palette or grayscale
        # PNGs would not give the 3 channels this pipeline is built for.
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Crop to c.cropsize_val, scale to a tensor, add the batch dimension.
        return self.transform(img).unsqueeze(dim=0).to(device)  # type: ignore

    if self.verbose:
        print("Encoding...")

    # Pure inference: without no_grad autograd would record the graph and keep
    # every intermediate activation alive for each call.
    with torch.no_grad():
        """prepare"""
        cover = to_batch(carrier)
        secret1 = to_batch(payload1)
        secret2 = to_batch(payload2)
        # (1, 3, 256, 256) -> (1, 12, 128, 128)
        cover_dwt = self.dwt(cover)
        secret1_dwt = self.dwt(secret1)
        secret2_dwt = self.dwt(secret2)

        """forward1"""
        # net1: [cover | secret1] (1, 24, 128, 128) -> (1, 24, 128, 128);
        # the first 4 * channels_in channels are steg1 in the DWT domain.
        out1 = self.net1(torch.cat((cover_dwt, secret1_dwt), dim=1))
        steg1_dwt = out1.narrow(1, 0, 4 * c.channels_in)
        steg1 = self.iwt(steg1_dwt).to(device)

        """forward2"""
        # All net3 inputs now live on the same device.
        if self.use_img_map:
            imp_map = self.net3(cover, secret1, steg1)
        else:
            imp_map = torch.zeros_like(cover)
        imp_map_dwt = self.dwt(imp_map)
        # net2: [steg1 | imp map | secret2] (1, 36, 128, 128) -> 36 channels;
        # the first 4 * channels_in channels are steg2 in the DWT domain.
        net2_in = torch.cat((steg1_dwt, imp_map_dwt, secret2_dwt), dim=1)
        out2 = self.net2(net2_in)
        steg2_dwt = out2.narrow(1, 0, 4 * c.channels_in)
        steg2 = self.iwt(steg2_dwt).to(device)

    return steg1, steg2

decode_multiple(container1, container2)

Decode the two secret images from the encoded steganographic images.

Parameters:

Name Type Description Default
container1 Image

Encoded carrier image1.

required
container2 Image

Encoded carrier image2.

required

Returns:

Type Description
tuple[Tensor, Tensor]

The two decoded secret images as torch.Tensor, in the order (secret2, secret1).

Source code in stegobox/codec/deepmih_2pnginpng_tpami2022/deepmih_2pnginpng.py
def decode_multiple(
    self, container1: Image.Image, container2: Image.Image
) -> tuple[torch.Tensor, torch.Tensor]:
    """Recover both secret images from the stego images.

    Both secrets are recovered from ``container2``:

    * ``net2`` is inverted to get the intermediate stego image and ``payload2``.
    * ``net1`` is then inverted on that recovered stego image to get
      ``payload1``.

    The pixel content of ``container1`` is not used. Only its shape (after
    cropping) sets the shape of the ``gauss_noise`` input for ``net1``.

    Args:
        container1: First stego image (``steg1`` from ``encode_multiple``).
        container2: Second stego image (``steg2`` from ``encode_multiple``).

    Returns:
        ``(secret2, secret1)``. Note the order: the image hidden as ``payload2``
        comes first. Each is a ``(1, 3, H, W)`` tensor on ``device``.
    """

    def to_batch(img: Image.Image) -> torch.Tensor:
        # ToTensor keeps the image's band count, so RGBA, palette or grayscale
        # PNGs would not give the 3 channels this pipeline is built for.
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Crop to c.cropsize_val, scale to a tensor, add the batch dimension.
        return self.transform(img).unsqueeze(dim=0).to(device)  # type: ignore

    if self.verbose:
        print("Decoding...")

    # Pure inference: without no_grad autograd would record the graph and keep
    # every intermediate activation alive for each call.
    with torch.no_grad():
        """prepare"""
        # (1, 12, 128, 128) DWT bands of each stego image.
        stego1_dwt = self.dwt(to_batch(container1))
        stego2_dwt = self.dwt(to_batch(container2))
        # net2's auxiliary branch spans both 12-channel slots (1, 24, 128, 128).
        # Build the shape directly instead of concatenating tensors for .shape.
        n, ch1, h, w = stego1_dwt.shape
        ch2 = stego2_dwt.shape[1]
        # Same call order as before, so the RNG stream is unchanged.
        noise1 = gauss_noise(stego1_dwt.shape)
        noise2 = gauss_noise(torch.Size((n, ch1 + ch2, h, w)))

        """backward2"""
        # Invert net2 on [steg2 | noise2] (1, 36, 128, 128); the result holds the
        # recovered steg1 and secret2, 4 * channels_in channels each.
        rev2 = self.net2(torch.cat((stego2_dwt, noise2), dim=1), rev=True)
        rev_steg1_dwt = rev2.narrow(1, 0, 4 * c.channels_in)
        secret2_dwt = rev2.narrow(1, 4 * c.channels_in, 4 * c.channels_in)
        # The intermediate container is available as self.iwt(rev_steg1_dwt).
        secret2 = self.iwt(secret2_dwt).to(device)

        """backward1"""
        # Invert net1 on [recovered steg1 | noise1] (1, 24, 128, 128).
        rev1 = self.net1(torch.cat((rev_steg1_dwt, noise1), dim=1), rev=True)
        secret1_dwt = rev1.narrow(1, 4 * c.channels_in, 4 * c.channels_in)
        secret1 = self.iwt(secret1_dwt).to(device)

    return secret2, secret1