diff --git a/ariths_gen/core/cgp_circuit.py b/ariths_gen/core/cgp_circuit.py index a453a7a..8cd381a 100644 --- a/ariths_gen/core/cgp_circuit.py +++ b/ariths_gen/core/cgp_circuit.py @@ -21,25 +21,37 @@ import re class UnsignedCGPCircuit(GeneralCircuit): """Unsigned circuit variant that loads CGP code and is able to export it to C/verilog/Blif/CGP.""" - def __init__(self, code: str, input_widths: list, prefix: str = "", name: str = "cgp", **kwargs): + def __init__(self, code: str = "", input_widths: list = None, inputs: list = None, prefix: str = "", name: str = "cgp", **kwargs): cgp_prefix, cgp_core, cgp_outputs = re.match( r"{(.*)}(.*)\(([^()]+)\)", code).groups() c_in, c_out, c_rows, c_cols, c_ni, c_no, c_lback = map( int, cgp_prefix.split(",")) + + assert inputs is not None or input_widths is not None, "Either inputs or input_widths must be provided" - assert sum( - input_widths) == c_in, f"CGP input width {c_in} doesn't match input_widths {input_widths}" + if inputs: + assert input_widths is None, "Only one of inputs or input_widths must be provided" - inputs = [Bus(N=bw, prefix=f"input_{chr(i)}") - for i, bw in enumerate(input_widths, start=0x61)] + input_widths =[i.N for i in inputs] + assert sum(input_widths) == c_in, f"CGP input width {c_in} doesn't match inputs {inputs_widths}" - # Assign each Bus object in self.inputs to a named attribute of self - for bus in inputs: - # Here, bus.prefix is 'input_a', 'input_b', etc. - # We strip 'input_' and use the remaining part (e.g., 'a', 'b') to create the attribute name - attr_name = bus.prefix.replace('input_', '') - setattr(self, attr_name, bus) + + + else: + + assert sum( + input_widths) == c_in, f"CGP input width {c_in} doesn't match input_widths {input_widths}" + + inputs = [Bus(N=bw, prefix=f"input_{chr(i)}") + for i, bw in enumerate(input_widths, start=0x61)] + + # Assign each Bus object in self.inputs to a named attribute of self + for bus in inputs: + # Here, bus.prefix is 'input_a', 'input_b', etc. + # We strip 'input_' and use the remaining part (e.g., 'a', 'b') to create the attribute name + attr_name = bus.prefix.replace('input_', '') + setattr(self, attr_name, bus) # Adding values to the list self.vals = {} @@ -51,6 +63,10 @@ class UnsignedCGPCircuit(GeneralCircuit): j += 1 super().__init__(prefix=prefix, name=name, out_N=c_out, inputs=inputs, **kwargs) + + if not code: + return # only for getting the name + cgp_core = cgp_core.split(")(") i = 0 @@ -125,6 +141,6 @@ class UnsignedCGPCircuit(GeneralCircuit): class SignedCGPCircuit(UnsignedCGPCircuit): """Signed circuit variant that loads CGP code and is able to export it to C/verilog/Blif/CGP.""" - def __init__(self, code: str, input_widths: list, prefix: str = "", name: str = "cgp", **kwargs): - super().__init__(code=code, input_widths=input_widths, prefix=prefix, name=name, signed=True, **kwargs) + def __init__(self, code: str, input_widths: list = None, inputs: list=None, prefix: str = "", name: str = "cgp", **kwargs): + super().__init__(code=code, input_widths=input_widths, inputs=inputs, prefix=prefix, name=name, signed=True, **kwargs) self.c_data_type = "int64_t"