mirror of
https://github.com/ehw-fit/ariths-gen.git
synced 2025-04-10 09:12:11 +01:00
Implementation of QuAd approximate adder
This commit is contained in:
parent
a4741db191
commit
a44b0638a1
@ -0,0 +1 @@
|
|||||||
|
from .quad import QuAdder
|
192
ariths_gen/multi_bit_circuits/approximate_adders/quad.py
Normal file
192
ariths_gen/multi_bit_circuits/approximate_adders/quad.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Implementation of QuAdder
|
||||||
|
|
||||||
|
For more information, see:
|
||||||
|
M. A. Hanif, R. Hafiz, O. Hasan and M. Shafique, "QuAd: Design and analysis of Quality-area optimal Low-Latency approximate Adders," 2017 54th ACM/EDAC/IEEE Design Automation Conference (DAC), Austin, TX, USA, 2017, pp. 1-6, doi: 10.1145/3061639.3062306.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ...wire_components import (
|
||||||
|
Wire,
|
||||||
|
ConstantWireValue0,
|
||||||
|
ConstantWireValue1,
|
||||||
|
Bus
|
||||||
|
)
|
||||||
|
from ariths_gen.core.arithmetic_circuits import (
|
||||||
|
ArithmeticCircuit,
|
||||||
|
MultiplierCircuit
|
||||||
|
)
|
||||||
|
from ariths_gen.one_bit_circuits.one_bit_components import (
|
||||||
|
HalfAdder,
|
||||||
|
FullAdder
|
||||||
|
)
|
||||||
|
from ariths_gen.one_bit_circuits.logic_gates import (
|
||||||
|
AndGate,
|
||||||
|
NandGate,
|
||||||
|
OrGate,
|
||||||
|
NorGate,
|
||||||
|
XorGate,
|
||||||
|
XnorGate,
|
||||||
|
NotGate
|
||||||
|
)
|
||||||
|
from ariths_gen.multi_bit_circuits.adders.ripple_carry_adder import UnsignedRippleCarryAdder
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
class QuAdder(ArithmeticCircuit):
    r"""
    Implementation of QuAd

    https://ieeexplore.ieee.org/document/8060326

    The adder is assembled from ``len(R)`` independent sub-adders. Sub-adder
    ``i`` is ``R[i] + P[i]`` bits wide; its ``P[i]`` lowest result bits are
    discarded (they only serve to predict the carry) and the next ``R[i]``
    bits are copied to the output. The carry-out of the last sub-adder
    becomes the final output bit.

    The implementation is inspired by Matlab code from the authors of the paper:
    ```matlab
    temp_count=1;
    for iij=1:length(R_vect)
        fprintf(fileID,['wire [' num2str(R_vect(iij)+P_vect(iij)) ':0] temp' num2str(temp_count) ';\n']);
        temp_count=temp_count + 1;
    end

    temp_count=1;
    for iiij=1:length(R_vect)
        if (sum(R_vect(1:iiij))+P_vect(1)-1) == (sum(R_vect(1:iiij))+P_vect(1)-R_vect(iiij)-P_vect(iiij))
            fprintf(fileID,['aassign temp' num2str(temp_count) '[' num2str(R_vect(iiij)+P_vect(iiij)) ':0] = in1[' num2str(sum(R_vect(1:iiij))+P_vect(1)-1) '] + in2[' num2str(sum(R_vect(1:iiij))+P_vect(1)-1) '];\n']);
        else
            disp(R_vect(1:iiij))
            fprintf(fileID,['bassign temp' num2str(temp_count) '[' num2str(R_vect(iiij)+P_vect(iiij)) ':0] = in1[' num2str(sum(R_vect(1:iiij))+P_vect(1)-1) ':' num2str(sum(R_vect(1:iiij))+P_vect(1)-R_vect(iiij)-P_vect(iiij)) '] + in2[' num2str(sum(R_vect(1:iiij))+P_vect(1)-1) ':' num2str(sum(R_vect(1:iiij))+P_vect(1)-R_vect(iiij)-P_vect(iiij)) '];\n']);
        end
        temp_count=temp_count+1;
    end

    statement='};\n';
    temp_count=1;
    for iiij=1:length(R_vect)
        if iiij ~= length(R_vect)
            if (R_vect(iiij)==1)
                statement = [', temp' num2str(temp_count) '[' num2str(R_vect(iiij)+P_vect(iiij)-1) '] ' statement];
            else
                statement = [', temp' num2str(temp_count) '[' num2str(R_vect(iiij)+P_vect(iiij)-1) ':' num2str(P_vect(iiij)) '] ' statement];
            end
        else
            statement = ['assign res[' num2str(N) ':0] =' '{ temp' num2str(temp_count) '[' num2str(R_vect(iiij)+P_vect(iiij)) ':' num2str(P_vect(iiij)) '] ' statement];
        end
        temp_count=temp_count+1;
    end
    ```
    """

    def log(self, *args):
        """Print ``args`` to stdout, but only when ``use_log`` was enabled."""
        if self.use_log:
            print(*args)

    def __init__(self, a, b, R, P, prefix, name="quad", adder_type=None, use_log=False, **kwargs):
        """
        :param a: Bus first input
        :param b: Bus second input
        :param R: list of integers, defines the resultant bits of all the sub-adders (the first index specifies the resultant bits of sub-adder 1 and so on)
        :param P: list of integers, defines the prediction bits of all the sub-adders (again the first index specifies the prediction bits of sub-adder 1 and so on)
        :param prefix: prefix forwarded to the parent circuit and used to name the internal buses and sub-adders
        :param name: circuit name forwarded to the parent circuit
        :param adder_type: class used for every sub-adder; it is called as ``adder_type(a_bus, b_bus, prefix=...)``. Defaults to UnsignedRippleCarryAdder
        :param use_log: when True, the Verilog-like debug trace is printed to stdout
        :param kwargs: forwarded unchanged to the parent constructor
        """
        if not adder_type:
            adder_type = UnsignedRippleCarryAdder

        # Set before validation so that self.log() can already be used.
        self.use_log = use_log

        # Assumptions checks
        assert len(R) == len(P), "R and P must have the same length"
        assert len(R) > 0, "R and P must not be empty"
        prediction_ok = [P[i] < P[i - 1] + R[i - 1] for i in range(1, len(P))]
        # Debug output only; it used to be printed unconditionally.
        self.log(prediction_ok)
        assert all(prediction_ok), "Pi must be lower than Pi-1 + Ri-1"
        # NOTE(review): the check uses a.N although the circuit width below is
        # max(a.N, b.N) - confirm this is intended when b is wider than a.
        assert sum(R) == a.N, "Sum of R must be equal to number of bits"

        self.N = max(a.N, b.N)
        super().__init__(a=a, b=b, prefix=prefix, name=name, out_N=self.N+1, **kwargs)

        # Bus sign extension in case buses have different lengths
        self.a.bus_extend(N=self.N, prefix=a.prefix)
        self.b.bus_extend(N=self.N, prefix=b.prefix)

        # Connect all outputs to zero
        for i in range(self.N+1):
            self.out[i] = ConstantWireValue0()

        # Declaration of temporary wires (just for debug purposes)
        for idx, (r_bits, p_bits) in enumerate(zip(R, P)):
            self.log('wire [' + str(r_bits + p_bits) + ':0] temp' + str(idx) + ';')

        def bus_subconnect(out_bus, in_bus, out_indexes, in_indexes):
            """Wire in_bus[in_indexes[k]] to out_bus[out_indexes[k]] for every k."""
            out_indexes = list(out_indexes)
            in_indexes = list(in_indexes)
            assert len(out_indexes) == len(in_indexes)

            for i, j in zip(out_indexes, in_indexes):
                if j >= in_bus.N:
                    out_bus[i] = ConstantWireValue0()  # unsigned extension
                else:
                    out_bus.connect(i, in_bus.get_wire(j))  # [i] = in_bus[j]

        # Connection of adders
        sub_adders = []
        # Exclusive upper input index of the current window; it equals
        # sum(R[0:idx + 1]) + P[0] once R[idx] has been added.
        upper = P[0]
        for idx, (r_bits, p_bits) in enumerate(zip(R, P)):
            width = r_bits + p_bits
            upper += r_bits
            lower = upper - width

            # Former verilog output
            self.log("assign temp{}[{}:0] = in1[{}:{}] + in2[{}:{}];".format(
                idx, width, upper - 1, lower, upper - 1, lower))

            a1 = Bus(f"{prefix}_temp_{idx}_a", width)
            b1 = Bus(f"{prefix}_temp_{idx}_b", width)

            bus_subconnect(b1, self.b, range(width), range(lower, upper))
            bus_subconnect(a1, self.a, range(width), range(lower, upper))

            sub_adders.append(self.add_component(
                adder_type(a1, b1, prefix=f"{prefix}_add_{idx}")
            ))

        # Final connection
        statement = "}"
        wire_id = 0
        last_idx = len(R) - 1
        for idx, (r_bits, p_bits) in enumerate(zip(R, P)):
            if idx != last_idx:
                if r_bits == 1:
                    statement = ', temp{}[{}]'.format(
                        idx, r_bits + p_bits - 1) + statement
                else:
                    statement = ', temp{}[{}:{}]'.format(
                        idx, r_bits + p_bits - 1, p_bits) + statement
            else:
                statement = 'assign res[' + str(self.N) + ':0] =' + '{ temp' + str(
                    idx) + '[' + str(r_bits + p_bits) + ':' + str(p_bits) + '] ' + statement

            self.log(statement)
            # Skip the p_bits prediction bits, keep the r_bits result bits.
            for i in range(p_bits, r_bits + p_bits):
                wire = sub_adders[idx].out[i]
                self.log(idx, i, wire_id, wire)
                self.out[wire_id] = wire
                wire_id += 1

        # Last carry (MSB)
        self.out[wire_id] = sub_adders[-1].out[R[-1] + P[-1]]
|
79
generate_quad_lib.py
Normal file
79
generate_quad_lib.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
This script generates the library of all possible QuAdders with N bits.
|
||||||
|
Note that the adders are not Pareto-optimal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ariths_gen.core.arithmetic_circuits.arithmetic_circuit import ArithmeticCircuit
|
||||||
|
from ariths_gen.core.arithmetic_circuits import GeneralCircuit
|
||||||
|
from ariths_gen.wire_components import Bus, Wire
|
||||||
|
from ariths_gen.multi_bit_circuits.adders import UnsignedRippleCarryAdder
|
||||||
|
from ariths_gen.multi_bit_circuits.approximate_adders import QuAdder
|
||||||
|
from ariths_gen.multi_bit_circuits.multipliers import UnsignedArrayMultiplier, UnsignedDaddaMultiplier
|
||||||
|
import os, sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Generate every valid QuAdder configuration (up to three sub-adders) for
    # N-bit operands and store it as C code, a C header, zipped Verilog
    # modules and a gzipped JSON metadata file.
    import gzip
    import io
    import json
    import zipfile

    # Operand bit width. It must be defined before the output paths below are
    # built from it.
    N = 8

    directory = f"lib_quad/lib_quad{N}"
    os.makedirs(directory, exist_ok=True)

    data = {}
    cnt = 0

    # The context managers guarantee that the C sources are flushed and that
    # the central directory of the zip archive is written, even on failure.
    with open(f"{directory}/lib_quad_{N}.c", "w") as cfile, \
            open(f"{directory}/lib_quad_{N}.h", "w") as hfile, \
            zipfile.ZipFile(file=f"{directory}/lib_quad_{N}.zip", mode="w",
                            compression=zipfile.ZIP_DEFLATED) as vfile:
        # generate the C code; verilog code is zipped
        hfile.write("#include <stdint.h>\n")

        # up to 3 stages
        for n in [1, 2, 3]:
            for R in itertools.product(range(1, N + 1), repeat=n):
                # skip invalid R
                if sum(R) != N:
                    continue

                for P in itertools.product(range(0, N + 1), repeat=n):
                    # test the condition from the paper
                    if not all(P[i] < P[i-1] + R[i-1] for i in range(1, len(P))):
                        continue
                    print(cnt, R, P)  # print the current configuration

                    prefix = f"quad_{N}"
                    name = "r_{}_p_{}".format("_".join([str(r) for r in R]), "_".join([str(p) for p in P]))

                    try:
                        c = QuAdder(Bus("a", N), Bus("b", N), R=R, P=P, name=name, prefix=prefix, use_log=False)
                        c.get_c_code_flat(file_object=cfile)

                        # The zip member is a byte stream; wrap it so the
                        # generator can write text. Closing the wrapper also
                        # closes the member, which is required before the
                        # next member can be opened for writing.
                        with io.TextIOWrapper(vfile.open(f"{prefix}_{name}.v", "w"), encoding="utf-8") as vt:
                            c.get_v_code_flat(file_object=vt)

                        cfile.write("\n\n")
                        # One prototype per line keeps the header readable.
                        hfile.write(f"uint64_t {prefix}_{name}(uint64_t a, uint64_t b);\n")

                        data[f"{name}_{prefix}"] = {
                            "bw": N,
                            "cfun": f"{prefix}_{name}",
                            "verilog": f"{prefix}_{name}.v",
                            "verilog_entity": f"{prefix}_{name}",
                            "quad_r": R,
                            "quad_p": P,
                        }
                        cnt += 1
                    except IOError as e:
                        print(R, P, e)

    # store the metadata
    with gzip.open(f"{directory}/lib_quad_{N}.json.gz", "wt") as jfile:
        json.dump(data, jfile, indent=4)
|
26
tests/test_ax.py
Normal file
26
tests/test_ax.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
Testing the QuAdder
|
||||||
|
"""
|
||||||
|
from ariths_gen.core.arithmetic_circuits.arithmetic_circuit import ArithmeticCircuit
|
||||||
|
from ariths_gen.core.arithmetic_circuits import GeneralCircuit
|
||||||
|
from ariths_gen.wire_components import Bus, Wire
|
||||||
|
from ariths_gen.multi_bit_circuits.adders import UnsignedRippleCarryAdder
|
||||||
|
from ariths_gen.multi_bit_circuits.approximate_adders import QuAdder
|
||||||
|
from ariths_gen.multi_bit_circuits.multipliers import UnsignedArrayMultiplier, UnsignedDaddaMultiplier
|
||||||
|
import os, sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
def test_quadder():
    """Check the error statistics of an 8-bit QuAdder against exact addition.

    The adder is built with R = [4, 2, 2] and P = [0, 2, 2], its hierarchical
    Verilog is printed to stdout, and it is then evaluated on all 256 x 256
    operand pairs.
    """
    bus_a = Bus("a", 8)
    bus_b = Bus("b", 8)
    adder = QuAdder(bus_a, bus_b, R=[4, 2, 2], P=[0, 2, 2], prefix="quad")
    adder.get_v_code_hier(file_object=sys.stdout)

    # Column vector against row vector: broadcasting covers every input pair.
    operands_a = np.arange(0, 256).reshape(-1, 1)
    operands_b = np.arange(0, 256).reshape(1, -1)

    approximate = adder(operands_a, operands_b)
    exact = operands_a + operands_b

    # Worst-case absolute error.
    assert np.abs(approximate - exact).max() == 64
    # Mean absolute error.
    np.testing.assert_equal(np.abs(approximate - exact).mean(), 7.5)
|
Loading…
x
Reference in New Issue
Block a user