From d9b56e8a00d74501583ccee82cadca6d15c5eda3 Mon Sep 17 00:00:00 2001 From: Honza Date: Thu, 6 Jan 2022 19:23:56 +0100 Subject: [PATCH] Fixed generation of unsigned variants of BAM and TM multipliers. Signed versions don't guarantee correct funcionality atm. --- .../arithmetic_circuits/multiplier_circuit.py | 20 ++- .../broken_array_multiplier.py | 158 ++++++++++-------- .../truncated_multiplier.py | 90 +++++----- tests/test_all.py | 12 +- 4 files changed, 157 insertions(+), 123 deletions(-) diff --git a/ariths_gen/core/arithmetic_circuits/multiplier_circuit.py b/ariths_gen/core/arithmetic_circuits/multiplier_circuit.py index 7ccfd99..59971c4 100644 --- a/ariths_gen/core/arithmetic_circuits/multiplier_circuit.py +++ b/ariths_gen/core/arithmetic_circuits/multiplier_circuit.py @@ -46,21 +46,27 @@ class MultiplierCircuit(ArithmeticCircuit): """ # To get the index of previous row's connecting adder and its generated pp if mult_type == "bam": + #TODO alter to be more compact ids_sum = 0 - for level in range(self.horizontal_cut, b_index): - # First pp level composed just from gates - if level == self.horizontal_cut: + for row in range(self.horizontal_cut + self.ommited_rows, b_index): + first_row_elem_id = self.vertical_cut-row if self.vertical_cut-row > 0 else 0 + # First pp row composed just from gates + if row == self.horizontal_cut + self.ommited_rows: # Minus one because the first component has index 0 instead of 1 - ids_sum += sum([1 for gate_pos in range(self.vertical_cut-level, self.N)])-1 + ids_sum += sum([1 for gate_pos in range(first_row_elem_id, self.N)])-1 + elif row == b_index-1: + ids_sum += sum([2 for gate_adder_pos in range(first_row_elem_id, self.N) if gate_adder_pos <= a_index+1]) else: - ids_sum += sum([2 for gate_adder_pos in range(self.vertical_cut-level, self.N) if gate_adder_pos <= a_index+1]) - index = ids_sum - + ids_sum += sum([2 for gate_adder_pos in range(first_row_elem_id, self.N)]) + # Index calculation should be redone, but it works even this way + index = ids_sum+2 if a_index == self.N-1 else ids_sum elif mult_type == "tm": index = ((b_index-self.truncation_cut-2) * ((self.N-self.truncation_cut)*2)) + ((self.N-self.truncation_cut-1)+2*(a_index-self.truncation_cut+2)) else: index = ((b_index-2) * ((self.N)*2)) + ((self.N-1)+2*(a_index+2)) + + # Get carry wire as input for the last adder in current row if a_index == self.N-1: index = index-2 diff --git a/ariths_gen/multi_bit_circuits/approximate_multipliers/broken_array_multiplier.py b/ariths_gen/multi_bit_circuits/approximate_multipliers/broken_array_multiplier.py index 3b09b25..6d078b7 100644 --- a/ariths_gen/multi_bit_circuits/approximate_multipliers/broken_array_multiplier.py +++ b/ariths_gen/multi_bit_circuits/approximate_multipliers/broken_array_multiplier.py @@ -37,48 +37,53 @@ class UnsignedBrokenArrayMultiplier(MultiplierCircuit): The design promises better area and power parameters in exchange for the loss of computation precision. The BAM design allows to save more partial product stage adders than truncated multiplier. - TODO - ``` - A3B0 A2B0 A1B0 A0B0 - │ │ │ │ │ │ │ │ - ┌▼─▼┐ ┌▼─▼┐ ┌▼─▼┐ ┌▼─▼┐ - │AND│ │AND│ │AND│ │AND│ - └┬──┘ └┬──┘ └┬──┘ └─┬─┘ - A3B1 │ A2B1 │ A1B1 │ A0B1 │ - ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ - │AND│ │ │AND│ │ │AND│ │ │AND│ │ - └┬──┘ │ └┬──┘ │ └┬──┘ │ └┬──┘ │ - │ │ │ │ │ │ │ │ - ┌───▼┐ ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ │ - │ │ │ │ │ │ │ │ │ - ┌───────┤ HA │◄──┤ FA │◄──┤ FA │◄──┤ HA │ │ - │ │ │ │ │ │ │ │ │ │ - │ └┬───┘ └┬───┘ └┬───┘ └─┬──┘ │ - │ A3B2 │ A2B2 │ A1B2 │ A0B2 │ │ - │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ │ - │ │AND│ │ │AND│ │ │AND│ │ │AND│ │ │ - │ └┬──┘ │ └┬──┘ │ └┬──┘ │ └┬──┘ │ │ - │ │ │ │ │ │ │ │ │ │ - ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ │ │ - │ │ │ │ │ │ │ │ │ │ - ┌───────┤ FA │◄──┤ FA │◄──┤ FA │◄──┤ HA │ │ │ - │ │ │ │ │ │ │ │ │ │ │ - │ └┬───┘ └┬───┘ └┬───┘ └─┬──┘ │ │ - │ A3B3 │ A2B3 │ A1B3 │ A0B3 │ │ │ - │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ │ │ - │ │AND│ │ │AND│ │ │AND│ │ │AND│ │ │ │ - │ └┬──┘ │ └┬──┘ │ └┬──┘ │ └┬──┘ │ │ │ - │ │ │ │ │ │ │ │ │ │ │ - ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ │ │ │ - │ │ │ │ │ │ │ │ │ │ │ - ┌──────┤ FA │◄──┤ FA │◄──┤ FA │◄──┤ HA │ │ │ │ - │ │ │ │ │ │ │ │ │ │ │ │ - │ └─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘ │ │ │ - │ │ │ │ │ │ │ │ - ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ - P7 P6 P5 P4 P3 P2 P1 P0 ``` + VERTICAL CUT=4 + │ + A3B0 A2B0 A1B0 A0B0 + │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + │AND│ │AND│ │AND│ │AND│ + │ └───┘ └───┘ └───┘ └───┘ + + │ + A3B1 A2B1 A1B1 A0B1 + │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + │AND│ │AND│ │AND│ │AND│ + │ └───┘ └───┘ └───┘ └───┘ + + │ ┌────┐ ┌────┐ ┌────┐ ┌────┐ + │ │ │ │ │ │ │ │ + │ │ HA │ │ FA │ │ FA │ │ HA │ + │ │ │ │ │ │ │ │ + │ └────┘ └────┘ └────┘ └────┘ + + ─ ─ ─ ─ ─ ─ ─ ─ ─ ┼ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ HORIZONTAL CUT=2 + A3B2 A2B2 A1B2 A0B2 + ┌▼─▼┐ │ ┌───┐ ┌───┐ ┌───┐ + │AND│ │AND│ │AND│ │AND│ + └┬──┘ │ └───┘ └───┘ └───┘ + │ + │ │ ┌────┐ ┌────┐ ┌────┐ + │ │ │ │ │ │ │ + │ │ │ FA │ │ FA │ │ HA │ + │ │ │ │ │ │ │ + │ │ └────┘ └────┘ └────┘ + A3B3 │ A2B3 A1B3 A0B3 + ┌▼─▼┐ │ ┌▼─▼┐ │ ┌───┐ ┌───┐ + │AND│ │ │AND│ │AND│ │AND│ + └┬──┘ │ └┬──┘ │ └───┘ └───┘ + │ │ │ + ┌▼───┐ ┌▼──▼┐ │ ┌────┐ ┌────┐ + │ │ │ │ │ │ │ │ + ┌──────┤ HA │◄────┤ HA │ │ │ FA │ │ HA │ + │ │ │ │ │ │ │ │ │ + │ └──┬─┘ └──┬─┘ │ └────┘ └────┘ + │ │ │ + │ │ │ │ + ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ + P7 P6 P5 │ P4=0 P3=0 P2=0 P1=0 P0=0 + ``` Description of the __init__ method. Args: @@ -90,45 +95,54 @@ class UnsignedBrokenArrayMultiplier(MultiplierCircuit): name (str, optional): Name of unsigned broken array multiplier. Defaults to "u_bam". """ def __init__(self, a: Bus, b: Bus, horizontal_cut: int = 0, vertical_cut: int = 0, prefix: str = "", name: str = "u_bam", **kwargs): - # Vertical cut should be greater or equal to horizontal cut - assert vertical_cut >= horizontal_cut - # NOTE: If horizontal/vertical cut is specified as 0 the final circuit is a simple array multiplier self.horizontal_cut = horizontal_cut self.vertical_cut = vertical_cut self.N = max(a.N, b.N) + # Horizontal cut level should be: 0 <= horizontal_cut < N + # Vertical cut level should be: horizontal_cut <= vertical_cut < 2*N + assert horizontal_cut < self.N + assert vertical_cut < 2*self.N + + # Vertical cut should be greater or equal to horizontal cut + assert vertical_cut >= horizontal_cut + super().__init__(a=a, b=b, prefix=prefix, name=name, out_N=self.N*2, **kwargs) # Bus sign extension in case buses have different lengths self.a.bus_extend(N=self.N, prefix=a.prefix) self.b.bus_extend(N=self.N, prefix=b.prefix) + self.ommited_rows = 0 # Gradual generation of partial products for b_multiplier_index in range(self.horizontal_cut, self.N): - for a_multiplicand_index in range(self.N): - # Skip generating the AND gates that should be ommited - if a_multiplicand_index+b_multiplier_index < (self.vertical_cut): - continue + # Number of elements that should be ommited in the current level based on vertical cut + pp_row_elems_to_skip = self.vertical_cut - b_multiplier_index if self.vertical_cut - b_multiplier_index > 0 else 0 + # Number of pp pairs present in the current row + pp_row_elems = self.N-pp_row_elems_to_skip if self.N-pp_row_elems_to_skip > 0 else 0 + self.ommited_rows += 1 if pp_row_elems == 0 else 0 + + for a_multiplicand_index in range((self.N-pp_row_elems), self.N): # AND gates generation for calculation of partial products obj_and = AndGate(self.a.get_wire(a_multiplicand_index), self.b.get_wire(b_multiplier_index), prefix=self.prefix+"_and"+str(a_multiplicand_index)+"_"+str(b_multiplier_index)) self.add_component(obj_and) - if b_multiplier_index != self.horizontal_cut: - if b_multiplier_index == self.horizontal_cut + 1: + if b_multiplier_index != self.horizontal_cut + self.ommited_rows: + if b_multiplier_index == self.horizontal_cut + self.ommited_rows + 1: previous_product = self.components[a_multiplicand_index + b_multiplier_index - self.vertical_cut].out else: previous_product = self.get_previous_partial_product(a_index=a_multiplicand_index, b_index=b_multiplier_index, mult_type="bam") # HA generation for first 1-bit adder in each row starting from the second one - if a_multiplicand_index == 0 or (self.vertical_cut-b_multiplier_index == a_multiplicand_index): + if a_multiplicand_index == 0 or self.vertical_cut-b_multiplier_index == a_multiplicand_index: obj_adder = HalfAdder(self.get_previous_component().out, previous_product, prefix=self.prefix+"_ha"+str(a_multiplicand_index)+"_"+str(b_multiplier_index)) self.add_component(obj_adder) # Product generation self.out.connect(b_multiplier_index, obj_adder.get_sum_wire()) # HA generation, last 1-bit adder in second row - elif a_multiplicand_index == self.N-1 and b_multiplier_index == self.horizontal_cut+1: + elif a_multiplicand_index == self.N-1 and b_multiplier_index == self.horizontal_cut+self.ommited_rows+1: obj_adder = HalfAdder(self.get_previous_component().out, self.get_previous_component(number=2).get_carry_wire(), prefix=self.prefix+"_ha"+str(a_multiplicand_index)+"_"+str(b_multiplier_index)) self.add_component(obj_adder) @@ -138,7 +152,7 @@ class UnsignedBrokenArrayMultiplier(MultiplierCircuit): self.add_component(obj_adder) # PRODUCT GENERATION - if (a_multiplicand_index == 0 and b_multiplier_index == self.horizontal_cut) or (self.horizontal_cut == self.N-1): + if (a_multiplicand_index == 0 and b_multiplier_index == self.horizontal_cut) or (self.horizontal_cut + self.ommited_rows == self.N-1): self.out.connect(a_multiplicand_index + b_multiplier_index, obj_and.out) # 1 bit multiplier case @@ -220,12 +234,19 @@ class SignedBrokenArrayMultiplier(MultiplierCircuit): name (str, optional): Name of signed broken array multiplier. Defaults to "s_bam". """ def __init__(self, a: Bus, b: Bus, horizontal_cut: int = 0, vertical_cut: int = 0, prefix: str = "", name: str = "s_bam", **kwargs): - #TODO - # NOTE: If horizontal/vertical break is specified as 0 the final circuit is a simple array multiplier + # NOTE: If horizontal/vertical cut is specified as 0 the final circuit is a simple array multiplier self.horizontal_cut = horizontal_cut self.vertical_cut = vertical_cut self.N = max(a.N, b.N) + # Horizontal cut level should be: 0 <= horizontal_cut < N + # Vertical cut level should be: horizontal_cut <= vertical_cut < 2*N + assert horizontal_cut < self.N + assert vertical_cut < 2*self.N + + # Vertical cut should be greater or equal to horizontal cut + assert vertical_cut >= horizontal_cut + super().__init__(a=a, b=b, prefix=prefix, name=name, out_N=self.N*2, signed=True, **kwargs) self.c_data_type = "int64_t" @@ -233,10 +254,16 @@ class SignedBrokenArrayMultiplier(MultiplierCircuit): self.a.bus_extend(N=self.N, prefix=a.prefix) self.b.bus_extend(N=self.N, prefix=b.prefix) - break_offsets = horizontal_cut + vertical_cut + self.ommited_rows = 0 # Gradual generation of partial products for b_multiplier_index in range(self.horizontal_cut, self.N): - for a_multiplicand_index in range(self.vertical_cut, self.N): + # Number of elements that should be ommited in the current level based on vertical cut + pp_row_elems_to_skip = self.vertical_cut - b_multiplier_index if self.vertical_cut - b_multiplier_index > 0 else 0 + # Number of pp pairs present in the current row + pp_row_elems = self.N-pp_row_elems_to_skip if self.N-pp_row_elems_to_skip > 0 else 0 + self.ommited_rows += 1 if pp_row_elems == 0 else 0 + + for a_multiplicand_index in range(self.N-pp_row_elems, self.N): # AND and NAND gates generation for calculation of partial products and sign extension if (b_multiplier_index == self.N-1 and a_multiplicand_index != self.N-1) or (b_multiplier_index != self.N-1 and a_multiplicand_index == self.N-1): obj_nand = NandGate(self.a.get_wire(a_multiplicand_index), self.b.get_wire(b_multiplier_index), prefix=self.prefix+"_nand"+str(a_multiplicand_index)+"_"+str(b_multiplier_index), parent_component=self) @@ -245,14 +272,14 @@ class SignedBrokenArrayMultiplier(MultiplierCircuit): obj_and = AndGate(self.a.get_wire(a_multiplicand_index), self.b.get_wire(b_multiplier_index), prefix=self.prefix+"_and"+str(a_multiplicand_index)+"_"+str(b_multiplier_index), parent_component=self) self.add_component(obj_and) - if b_multiplier_index != self.horizontal_cut and self.vertical_cut != self.N-1: - if b_multiplier_index == self.horizontal_cut + 1: - previous_product = self.components[a_multiplicand_index + b_multiplier_index - break_offsets].out + if b_multiplier_index != self.horizontal_cut + self.ommited_rows: + if b_multiplier_index == self.horizontal_cut + self.ommited_rows + 1: + previous_product = self.components[a_multiplicand_index + b_multiplier_index - self.vertical_cut].out else: - previous_product = self.get_previous_partial_product(a_index=a_multiplicand_index, b_index=b_multiplier_index, horizontal_cut=horizontal_cut, vertical_cut=vertical_cut) + previous_product = self.get_previous_partial_product(a_index=a_multiplicand_index, b_index=b_multiplier_index, mult_type="bam") # HA generation for first 1-bit adder in each row starting from the second one - if a_multiplicand_index == self.vertical_cut: + if a_multiplicand_index == 0 or self.vertical_cut-b_multiplier_index == a_multiplicand_index: obj_adder = HalfAdder(self.get_previous_component().out, previous_product, prefix=self.prefix+"_ha"+str(a_multiplicand_index)+"_"+str(b_multiplier_index)) self.add_component(obj_adder) # Product generation @@ -261,16 +288,15 @@ class SignedBrokenArrayMultiplier(MultiplierCircuit): # FA generation else: # Constant wire with value 1 used at the last FA in second row (as one of its inputs) for signed multiplication (based on Baugh Wooley algorithm) - if a_multiplicand_index == self.N-1 and b_multiplier_index == self.horizontal_cut+1: + if a_multiplicand_index == self.N-1 and b_multiplier_index == self.horizontal_cut+self.ommited_rows+1: previous_product = ConstantWireValue1() obj_adder = FullAdder(self.get_previous_component().out, previous_product, self.get_previous_component(number=2).get_carry_wire(), prefix=self.prefix+"_fa"+str(a_multiplicand_index)+"_"+str(b_multiplier_index)) self.add_component(obj_adder) # PRODUCT GENERATION - if (a_multiplicand_index == self.vertical_cut and b_multiplier_index == self.horizontal_cut) or (self.horizontal_cut == self.N-1 or self.vertical_cut == self.N-1): + if (a_multiplicand_index == 0 and b_multiplier_index == self.horizontal_cut) or (self.horizontal_cut + self.ommited_rows == self.N-1): self.out.connect(a_multiplicand_index + b_multiplier_index, obj_and.out) - # 1 bit multiplier case if a_multiplicand_index == self.N-1 and b_multiplier_index == self.N-1: obj_nor = NorGate(ConstantWireValue1(), self.get_previous_component().out, prefix=self.prefix+"_nor_zero_extend", parent_component=self) @@ -288,8 +314,8 @@ class SignedBrokenArrayMultiplier(MultiplierCircuit): self.out.connect(self.out.N-1, obj_xor.out) # Connecting the output bits generated from ommited cells to ground - if self.horizontal_cut >= self.N or self.vertical_cut >= self.N: + if self.horizontal_cut >= self.N or self.vertical_cut >= 2*self.N: [self.out.connect(out_id, ConstantWireValue0()) for out_id in range(self.out.N)] else: - for grounded_out_index in range(0, break_offsets): + for grounded_out_index in range(0, max(self.horizontal_cut, self.vertical_cut)): self.out.connect(grounded_out_index, ConstantWireValue0()) diff --git a/ariths_gen/multi_bit_circuits/approximate_multipliers/truncated_multiplier.py b/ariths_gen/multi_bit_circuits/approximate_multipliers/truncated_multiplier.py index 068d7ae..dd51dc0 100644 --- a/ariths_gen/multi_bit_circuits/approximate_multipliers/truncated_multiplier.py +++ b/ariths_gen/multi_bit_circuits/approximate_multipliers/truncated_multiplier.py @@ -27,7 +27,6 @@ from ariths_gen.multi_bit_circuits.multipliers import( SignedArrayMultiplier ) - class UnsignedTruncatedMultiplier(MultiplierCircuit): """Class representing unsigned truncated multiplier. @@ -35,49 +34,46 @@ class UnsignedTruncatedMultiplier(MultiplierCircuit): It is created by modifying an ordinary N-bit unsigned array multiplier by ignoring (truncating) some of the partial products. - The design promises better area and power parameters in exchange for the loss of computation precision. - - ```TODO - A3B0 A2B0 A1B0 A0B0 - │ │ │ │ │ │ │ │ - ┌▼─▼┐ ┌▼─▼┐ ┌▼─▼┐ ┌▼─▼┐ - │AND│ │AND│ │AND│ │AND│ - └┬──┘ └┬──┘ └┬──┘ └─┬─┘ - A3B1 │ A2B1 │ A1B1 │ A0B1 │ - ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ - │AND│ │ │AND│ │ │AND│ │ │AND│ │ - └┬──┘ │ └┬──┘ │ └┬──┘ │ └┬──┘ │ - │ │ │ │ │ │ │ │ - ┌───▼┐ ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ │ - │ │ │ │ │ │ │ │ │ - ┌───────┤ HA │◄──┤ FA │◄──┤ FA │◄──┤ HA │ │ - │ │ │ │ │ │ │ │ │ │ - │ └┬───┘ └┬───┘ └┬───┘ └─┬──┘ │ - │ A3B2 │ A2B2 │ A1B2 │ A0B2 │ │ - │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ │ - │ │AND│ │ │AND│ │ │AND│ │ │AND│ │ │ - │ └┬──┘ │ └┬──┘ │ └┬──┘ │ └┬──┘ │ │ - │ │ │ │ │ │ │ │ │ │ - ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ │ │ - │ │ │ │ │ │ │ │ │ │ - ┌───────┤ FA │◄──┤ FA │◄──┤ FA │◄──┤ HA │ │ │ - │ │ │ │ │ │ │ │ │ │ │ - │ └┬───┘ └┬───┘ └┬───┘ └─┬──┘ │ │ - │ A3B3 │ A2B3 │ A1B3 │ A0B3 │ │ │ - │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ ┌▼─▼┐ │ │ │ - │ │AND│ │ │AND│ │ │AND│ │ │AND│ │ │ │ - │ └┬──┘ │ └┬──┘ │ └┬──┘ │ └┬──┘ │ │ │ - │ │ │ │ │ │ │ │ │ │ │ - ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ ┌▼──▼┐ │ │ │ - │ │ │ │ │ │ │ │ │ │ │ - ┌──────┤ FA │◄──┤ FA │◄──┤ FA │◄──┤ HA │ │ │ │ - │ │ │ │ │ │ │ │ │ │ │ │ - │ └─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘ │ │ │ - │ │ │ │ │ │ │ │ - ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ - P7 P6 P5 P4 P3 P2 P1 P0 + The design promises better area and power parameters in exchange for the loss of computation precision. + ``` + CUT=2 + A3B0 A2B0 │ A1B0 A0B0 + ┌───┐ ┌───┐ ┌───┐ ┌───┐ + │AND│ │AND│ │ │AND│ │AND│ + └───┘ └───┘ └───┘ └───┘ + ┌ ─ ─ ─ ┘ + A3B1 A2B1 A1B1 A0B1 + ┌───┐ ┌───┐ │ ┌───┐ ┌───┐ + │AND│ │AND│ │AND│ │AND│ + └───┘ └───┘ │ └───┘ └───┘ + ┌────┐ ┌────┐ ┌────┐ ┌────┐ + │ │ │ │ │ │ │ │ │ + │ HA │ │ FA │ │ FA │ │ HA │ + │ │ │ │ │ │ │ │ │ + └────┘ └────┘ └────┘ └────┘ + ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┬ ─ ─ ─ ─ ┴─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ CUT=2 + A3B2 A2B2 A1B2 A0B2 + ┌▼─▼┐ ┌▼─▼┐ │ ┌───┐ ┌───┐ + │AND│ │AND│ │AND│ │AND│ + └┬──┘ └┬──┘ │ └───┘ └───┘ + │ │ ┌────┐ ┌────┐ + │ │ │ │ │ │ │ + │ ┌ ─ ┼─ ─ ─ ┘ │ FA │ │ HA │ + │ │ │ │ │ │ + │ │ │ └────┘ └────┘ + A3B3 │ A2B3 │ A1B3 A0B3 + ┌◄─►┐ │ ┌◄─►┐ │ │ ┌───┐ ┌───┐ + │AND│ │ │AND│ │ │AND│ │AND│ + └┬──┘ │ └┬──┘ │ │ └───┘ └───┘ + ┌───▼┐ ┌▼──▼┐ ┌┼───┐ ┌────┐ + │ │ │ │ │ ││ │ │ │ + ┌──────┤ HA │◄────┤ HA │ ││FA │ │ HA │ + │ │ │ │ │ │ ││ │ │ │ + │ └──┬─┘ └──┬─┘ └┼───┘ └────┘ + │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ + P7 P6 P5 │ P4 P3=0 P2=0 P1=0 P0=0 ``` - Description of the __init__ method. Args: @@ -92,6 +88,9 @@ class UnsignedTruncatedMultiplier(MultiplierCircuit): self.truncation_cut = truncation_cut self.N = max(a.N, b.N) + # Cut level should be: 0 <= truncation_cut < N + assert truncation_cut < self.N + super().__init__(a=a, b=b, prefix=prefix, name=name, out_N=self.N*2, **kwargs) # Bus sign extension in case buses have different lengths @@ -208,8 +207,11 @@ class SignedTruncatedMultiplier(MultiplierCircuit): def __init__(self, a: Bus, b: Bus, truncation_cut: int = 0, prefix: str = "", name: str = "s_tm", **kwargs): # NOTE: If truncation_cut is specified as 0 the final circuit is a simple array multiplier self.truncation_cut = truncation_cut - + self.N = max(a.N, b.N) + # Cut level should be: 0 <= truncation_cut < N + assert truncation_cut < self.N + super().__init__(a=a, b=b, prefix=prefix, name=name, out_N=self.N*2, signed=True, **kwargs) self.c_data_type = "int64_t" diff --git a/tests/test_all.py b/tests/test_all.py index 28f82df..4f8bd2a 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -62,14 +62,14 @@ def test_unsigned_approxmul(values = False): np.seterr(divide='ignore', invalid='ignore') WCRE = np.max(np.nan_to_num(abs(np.subtract(r, expected)) / expected)) - if isinstance(multiplier, UnsignedTruncatedMultiplier): + if isinstance(mul, UnsignedTruncatedMultiplier): # WCE_TM(n,k) = (2^k - 1) * (2^(n+1) - 2^k - 1) - expected_WCE = (2 ** multiplier.truncation_cut - 1) * (2 ** (multiplier.a.N+1) - 2 ** multiplier.truncation_cut - 1) - elif isinstance(multiplier, UnsignedBrokenArrayMultiplier): + expected_WCE = (2 ** mul.truncation_cut - 1) * (2 ** (mul.a.N+1) - 2 ** mul.truncation_cut - 1) + elif isinstance(mul, UnsignedBrokenArrayMultiplier): # WCE_BAM(n,h,v) = (2^n - 1) * {SUM_i0_to_h-1}(2^i) + 2^h * {SUM_i0_to_v-h-1}(2^(v-h) - 2^i) - sum_1 = sum([2**i for i in range(0, multiplier.horizontal_cut)]) - sum_2 = sum([2**(multiplier.vertical_cut-multiplier.horizontal_cut) - 2**i for i in range(0, multiplier.vertical_cut-multiplier.horizontal_cut)]) - expected_WCE = (2 ** multiplier.N - 1) * sum_1 + 2 ** multiplier.horizontal_cut * sum_2 + sum_1 = sum([2**i for i in range(0, mul.horizontal_cut)]) + sum_2 = sum([2**(mul.vertical_cut-mul.horizontal_cut) - 2**i for i in range(0, mul.vertical_cut-mul.horizontal_cut)]) + expected_WCE = (2 ** mul.N - 1) * sum_1 + 2 ** mul.horizontal_cut * sum_2 # Test expected result assert expected_WCE == WCE