flowy/filter.py
2014-06-26 08:47:04 +02:00

196 lines
6.4 KiB
Python

from copy import deepcopy
from copy import copy
from statement import Field
import time
import profiler
class NoMatch(Exception):
    """Signal that a record failed a PreSplitRule.

    PreSplitRule.match raises it when its rule evaluates falsy, and
    Filter.__iter__ catches it and skips the record without yielding it.
    """
class Filter(object):
    """Tag each record with a per-branch mask and keep records that hit a branch.

    Iterated from the splitter's go() function. For every record the shared
    branch mask is reset, every rule is matched against the record, and each
    result is folded into the mask under that rule's ``branch_mask``. The
    resulting per-branch list is yielded together with the record.

    For example, if branch A filters on srcport=443, a record with that
    source port gets [True, False]; a record matching only branch B's filter
    gets [False, True]. Records that reach no branch are dropped, as are
    records for which a rule raises NoMatch.
    """

    def __init__(self, rules, records, br_mask, nbranches):
        self.rules = rules
        self.records = records
        self.br_mask = br_mask
        # NOTE(review): nbranches is accepted but never stored or used here.

    def _branches_for(self, record):
        """Return the per-branch list for *record*, or None on NoMatch."""
        self.br_mask.reset()
        try:
            for rule in self.rules:
                outcome = rule.match(record)
                self.br_mask.mask(rule.branch_mask, outcome)
        except NoMatch:
            # A pre-split rule rejected the record outright.
            return None
        return self.br_mask.final_result()

    def __iter__(self):
        """Yield ``(record, branches)`` for every record reaching a branch."""
        print("Started filtering")
        for record in self.records:
            branches = self._branches_for(record)
            if branches is not None and True in branches:
                yield record, branches
#class Field(object):
# def __init__(self, name):
# self.name = name
# def __repr__(self):
# return "Field('%s')"%self.name
# Lightweight stand-in for copy.deepcopy used by BranchMask below.
def deep_copy(org):
    """Return a new dict whose values are one-level copies of *org*'s values.

    Despite the name, only one level is copied. Each value is copied with
    ``.copy()`` when it has that method (dicts, sets), otherwise sliced with
    ``[:]`` (lists, tuples, strings, unicode). Anything supporting neither
    (ints, None) is shared as-is. Containers nested inside the values are
    still shared with *org*.
    """
    def copy_value(value):
        try:
            return value.copy()  # dicts, sets
        except AttributeError:
            try:
                return value[:]  # lists, tuples, strings, unicode
            except TypeError:
                return value  # ints and other unsliceable scalars

    return dict((key, copy_value(value)) for key, value in org.items())
class BranchMask(object):
    """Per-record scratch masks that combine rule results into branch verdicts.

    ``masks`` maps a branch id to a list of slots. Slot 0 is combined with
    ``and`` and every other slot with ``or`` (see ``mask``). ``final_result``
    turns the slots into one verdict per real branch. Ids with no entry in
    ``masks`` are resolved through ``pseudo_branches``.
    """

    def __init__(self, branch_masks, pseudo_branches, n_real_branches):
        # NOTE(review): until the first reset(), self.masks is the caller's
        # own dict, so mask() before reset() mutates branch_masks in place.
        self.masks = branch_masks
        # One-level snapshots (via deep_copy rather than copy.deepcopy) so
        # reset() can restore the starting masks before each record.
        self.orig_mask = deep_copy(branch_masks)
        self.pseudo_branches = deep_copy(pseudo_branches)
        self.n_real_branches = n_real_branches

    def reset(self):
        """Replace ``masks`` with a fresh copy of the starting masks."""
        self.masks = deep_copy(self.orig_mask)

    def mask(self, sub_branches, result):
        """Fold one rule result into the slots listed in *sub_branches*.

        Each entry is ``(branch, slot, negate)``. When ``negate`` is truthy
        the result is inverted before being combined.
        """
        for branch, slot, negate in sub_branches:
            value = (not result) if negate else result
            slots = self.masks[branch]
            if slot == 0:
                slots[slot] = slots[slot] and value
            else:
                slots[slot] = slots[slot] or value

    def final_result(self):
        """Return one verdict per real branch id ``0 .. n_real_branches - 1``.

        A branch present in ``masks`` passes when none of its slots equals
        ``False``. Any other id is looked up in ``pseudo_branches`` as a list
        of OR-groups of ``(branch, negate)`` references. It passes when every
        OR-group has at least one satisfied reference.
        """
        verdicts = dict((branch, False not in slots)
                        for branch, slots in self.masks.items())

        def satisfied(ref):
            # ref[0] is a branch id; a truthy ref[1] negates its verdict.
            return (not verdicts[ref[0]]) if ref[1] else verdicts[ref[0]]

        result = []
        for branch_id in range(self.n_real_branches):
            if branch_id in verdicts:
                result.append(verdicts[branch_id])
                continue
            # Every OR-group is evaluated before combining (no early exit).
            group_hits = [any(satisfied(ref) for ref in or_group)
                          for or_group in self.pseudo_branches[branch_id]]
            result.append(all(group_hits))
        return result
class Rule(object):
    """A predicate evaluated as ``operation(*resolved_args)`` on a record.

    Used at both the filtering and group-filtering stages. Each entry of
    ``args`` is resolved before the call:

    * a ``Field`` becomes the record attribute with that field's name;
    * a nested ``Rule`` (group-filtering stage only) becomes its match result;
    * anything else (e.g. a literal port number such as 80) is passed as-is.

    ``branch_mask`` is not read here; Filter hands it to the branch mask.
    """

    def __init__(self, branch_mask, operation, args):
        self.branch_mask = branch_mask
        self.operation = operation
        self.args = args

    def match(self, record):
        """Return ``operation`` applied to ``args`` resolved against *record*."""
        def resolve(arg):
            # Exact type checks, not isinstance: instances of subclasses
            # (e.g. PreSplitRule) are passed through unchanged.
            arg_type = type(arg)
            if arg_type is Field:
                return getattr(record, arg.name)
            if arg_type is Rule:
                return arg.match(record)
            return arg

        return self.operation(*[resolve(arg) for arg in self.args])
class PreSplitRule(Rule):
    """A Rule that rejects the whole record when it does not match.

    A falsy match raises NoMatch, which Filter.__iter__ catches to skip the
    record. A truthy match returns the result just like Rule.match.
    """

    def match(self, record):
        """Return the rule's result for *record*; raise NoMatch if falsy."""
        result = Rule.match(self, record)
        if not result:
            raise NoMatch()
        # Return the value instead of falling through to None. Filter passes
        # it to BranchMask.mask, where a negated entry would otherwise turn
        # None into True and count a matching rule as a non-match.
        return result
class GroupFilter(object):
    """Select the groups of one branch that satisfy its group-filter rules.

    ``rules`` is a list of OR-groups, each a list of rules. A record is kept
    when every OR-group has at least one rule whose ``match`` is truthy, so
    the rules form an AND of ORs. Kept records are numbered via ``rec_id``,
    added to ``index`` and appended to ``groups_table``.
    """

    def __init__(self, rules, records, branch_name, groups_table, index):
        self.rules = rules
        self.records = records
        self.branch_name = branch_name
        self.index = index
        self.groups_table = groups_table
        # NOTE(review): RecordReader does not appear among this module's
        # visible imports; confirm where it is expected to come from.
        self.record_reader = RecordReader(self.groups_table)

    def _matches(self, record):
        """Return True when every OR-group has at least one matching rule.

        An empty rule list accepts every record (vacuous AND). An empty
        OR-group rejects every record. Evaluation short-circuits in the same
        order as before: the first matching rule ends an OR-group, and the
        first failing OR-group ends the check.
        """
        return all(any(rule.match(record) for rule in or_rules)
                   for or_rules in self.rules)

    def go(self):
        """Filter ``records`` into ``index`` and ``groups_table``, then flush."""
        count = 0
        for record in self.records:
            # Previously an empty rule list left the match flag unassigned
            # and raised UnboundLocalError here.
            if not self._matches(record):
                continue
            # rec_id is the record's position among the kept groups.
            record.rec_id = count
            count += 1
            self.index.add(record)
            self.groups_table.append(record)
        print("Finished filtering groups for branch " + self.branch_name)
        self.groups_table.flush()

    def __iter__(self):
        """Yield the records produced by the RecordReader over groups_table."""
        for rec in self.record_reader:
            yield rec
class AcceptGroupFilter(GroupFilter):
    """A GroupFilter without rules: every record of the branch is kept."""

    def __init__(self, records, branch_name, groups_table, index):
        # No rules are passed because go() below never consults them.
        GroupFilter.__init__(self, None, records, branch_name,
                             groups_table, index)

    def go(self):
        """Number every record and add it to ``index`` and ``groups_table``."""
        for position, record in enumerate(self.records):
            record.rec_id = position
            self.index.add(record)
            self.groups_table.append(record)
        print("Finished filtering groups for branch " + self.branch_name)
        self.groups_table.flush()