#!/usr/bin/env python3
#
# This script is in the public domain
#
# Author: Johannes Schauer Marin Rodrigues <josch@mister-muffin.de>
#
# This script accepts a tarball on standard input and filters it according to
# the same rules used by dpkg --path-exclude and --path-include, using command
# line options of the same name. The result is then printed on standard output.
#
# A tool like this should be written in C but libarchive has issues:
# https://github.com/libarchive/libarchive/issues/587
# https://github.com/libarchive/libarchive/pull/1288/ (needs 3.4.1)
# Should these issues get fixed, then a good template is tarfilter.c in the
# examples directory of libarchive.
#
# We are not using Perl either, because Archive::Tar slurps the whole tarball
# into memory.
#
# We could also use Go but meh...
# https://stackoverflow.com/a/59542307/784669

import tarfile
import sys
import argparse
import fnmatch
import re


class _GlobPattern:
    """Shell glob compiled for matching that still remembers its own text.

    The object offers the two members the path filter relies on:

    * ``match(string)`` -- behaves like ``re.Pattern.match`` on the regular
      expression produced by ``fnmatch.translate`` for the glob.
    * ``pattern`` -- the *original glob* as typed by the user.

    Keeping the glob text matters because the dpkg-compatible directory
    handling needs the literal part of the glob in front of its first
    wildcard.  The translated regex (for example ``(?s:/usr/.*)\\Z`` on
    current Python versions) starts with regex syntax instead of the path, so
    it cannot be used for that purpose.
    """

    __slots__ = ("pattern", "_regex")

    def __init__(self, glob):
        self.pattern = glob
        self._regex = re.compile(fnmatch.translate(glob))

    def match(self, string):
        """Return a match object if the whole of *string* matches the glob, else None."""
        return self._regex.match(string)

    def __repr__(self):
        return f"{type(self).__name__}({self.pattern!r})"


class PathFilterAction(argparse.Action):
    """Collect --path-exclude/--path-include options in command line order.

    Every occurrence appends a ``(dest, matcher)`` tuple to the list stored in
    ``namespace.pathfilter``.  ``dest`` is the option's destination name
    ("path_exclude" or "path_include" for the options registered by this
    script) and ``matcher`` is a ``_GlobPattern``.  The attribute is only
    created once one of the options was actually seen, so its absence means
    that no path filter was requested.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        items = getattr(namespace, "pathfilter", [])
        # Order is significant: later rules override earlier ones.
        items.append((self.dest, _GlobPattern(values)))
        setattr(namespace, "pathfilter", items)


class PaxFilterAction(argparse.Action):
    """Accumulate pax header filter rules on ``namespace.paxfilter``.

    Each use of an option bound to this action adds one ``(dest, regex)``
    tuple, where ``dest`` is the option's destination name and ``regex`` is
    the glob argument compiled via ``fnmatch.translate``.  Rules are kept in
    the order they were given.  The attribute does not exist on the namespace
    until the first rule is added.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        compiled = re.compile(fnmatch.translate(values))
        try:
            rules = namespace.paxfilter
        except AttributeError:
            # first pax rule on this command line
            rules = []
            namespace.paxfilter = rules
        rules.append((self.dest, compiled))


def main():
    """Filter the tarball on standard input and write the result to standard output.

    Command line options select which members are dropped by path
    (--path-exclude/--path-include), which pax extended headers are dropped
    (--pax-exclude/--pax-include) and how many leading path components are
    removed from member names (--strip-components).  When none of these is
    requested, standard input is copied to standard output unchanged.
    """
    parser = argparse.ArgumentParser(
        # The description contains a hand-aligned table of wildcards, so it
        # must not be re-wrapped by argparse.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""\
Filters a tarball on standard input by the same rules as the dpkg --path-exclude
and --path-include options and writes resulting tarball to standard output. See
dpkg(1) for information on how these two options work in detail. Since this is
meant for filtering tarballs storing a rootfs, notice that paths must be given
as /path and not as ./path even though they might be stored as such in the
tarball.

Similarly, filter out unwanted pax extended headers. This is useful in cases
where a tool only accepts certain xattr prefixes. For example tar2sqfs only
supports SCHILY.xattr.user.*, SCHILY.xattr.trusted.* and
SCHILY.xattr.security.* but not SCHILY.xattr.system.posix_acl_default.*.

Both types of options use Unix shell-style wildcards:

       * matches everything
       ? matches any single character
   [seq] matches any character in seq
  [!seq] matches any character not in seq

Thirdly, strip leading directory components off of tar members. Just as with
GNU tar --strip-components, tar members that have less or equal components in
their path are not passed through.
""",
    )
    parser.add_argument(
        "--path-exclude",
        metavar="pattern",
        action=PathFilterAction,
        help="Exclude path matching the given shell pattern.",
    )
    parser.add_argument(
        "--path-include",
        metavar="pattern",
        action=PathFilterAction,
        help="Re-include a pattern after a previous exclusion.",
    )
    parser.add_argument(
        "--pax-exclude",
        metavar="pattern",
        action=PaxFilterAction,
        help="Exclude pax header matching the given globbing pattern.",
    )
    parser.add_argument(
        "--pax-include",
        metavar="pattern",
        action=PaxFilterAction,
        help="Re-include a pax header after a previous exclusion.",
    )
    parser.add_argument(
        "--strip-components",
        metavar="number",
        type=int,
        help="Strip NUMBER leading components from file names",
    )
    args = parser.parse_args()

    # The two filter actions only create "pathfilter"/"paxfilter" on the
    # namespace when one of their options is given, so hasattr() is the right
    # test for them.  --strip-components is an ordinary option which argparse
    # always sets (to None when absent), so it has to be tested by value;
    # testing it with hasattr() would make the fast path below unreachable.
    has_path_rules = hasattr(args, "pathfilter")
    has_pax_rules = hasattr(args, "paxfilter")
    if not has_path_rules and not has_pax_rules and not args.strip_components:
        # Nothing to do: pass the input through byte for byte.
        from shutil import copyfileobj

        copyfileobj(sys.stdin.buffer, sys.stdout.buffer)
        return

    # same logic as in dpkg/src/filters.c/filter_should_skip()
    # Captures everything in front of the first glob metacharacter.
    prefix_prog = re.compile(r"^([^*?[\\]*).*")

    def path_filter_should_skip(member):
        """Return True if the tar member is to be dropped by the path rules."""
        if not has_path_rules:
            return False
        # Drop the first character of the stored name so that a member stored
        # as "./usr" is compared as "/usr", which is how patterns are given.
        path = member.name[1:]
        skip = False
        # The last matching rule wins.
        for kind, rule in args.pathfilter:
            if rule.match(path) is not None:
                skip = kind != "path_include"
        # An excluded directory or symlink is kept anyway if its path starts
        # with the wildcard-free prefix of any include rule, as something
        # below it might be re-included.
        # NOTE(review): the prefix is taken from rule.pattern, which only
        # mirrors dpkg if that attribute holds glob text rather than a
        # translated regular expression.
        if skip and (member.isdir() or member.issym()):
            for kind, rule in args.pathfilter:
                if kind != "path_include":
                    continue
                prefix = prefix_prog.sub(r"\1", rule.pattern).rstrip("/")
                if path.startswith(prefix):
                    return False
        return skip

    def pax_filter_should_skip(header):
        """Return True if the pax header named *header* is to be dropped."""
        if not has_pax_rules:
            return False
        skip = False
        # The last matching rule wins.
        for kind, rule in args.paxfilter:
            if rule.match(header) is not None:
                skip = kind != "pax_include"
        return skip

    # starting with Python 3.8, the default format became PAX_FORMAT, so
    # passing it explicitly is only for compatibility with older versions of
    # Python 3
    with tarfile.open(fileobj=sys.stdin.buffer, mode="r|*") as in_tar, tarfile.open(
        fileobj=sys.stdout.buffer, mode="w|", format=tarfile.PAX_FORMAT
    ) as out_tar:
        for member in in_tar:
            if path_filter_should_skip(member):
                continue
            if args.strip_components:
                comps = member.name.split("/")
                # members without anything left after stripping are dropped
                if len(comps) <= args.strip_components:
                    continue
                member.name = "/".join(comps[args.strip_components :])
            member.pax_headers = {
                key: value
                for key, value in member.pax_headers.items()
                if not pax_filter_should_skip(key)
            }
            if member.isfile():
                # regular files carry data that has to be copied as well
                with in_tar.extractfile(member) as payload:
                    out_tar.addfile(member, payload)
            else:
                out_tar.addfile(member)


# Run the filter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()