diff --git a/ariths_gen/core/arithmetic_circuits/general_circuit.py b/ariths_gen/core/arithmetic_circuits/general_circuit.py index 94b7e89..20b380d 100644 --- a/ariths_gen/core/arithmetic_circuits/general_circuit.py +++ b/ariths_gen/core/arithmetic_circuits/general_circuit.py @@ -30,16 +30,22 @@ class GeneralCircuit(): self.inputs = [] input_names = "abcdefghijklmnopqrstuvwxyz" # This should be enough.. assert len(input_names) >= len(inputs) - for i, input_bus in enumerate(inputs): + for i, input in enumerate(inputs): attr_name = input_names[i] - full_prefix = f"{self.prefix}_{input_bus.prefix}" if self.inner_component else f"{input_bus.prefix}" - bus = Bus(prefix=full_prefix, wires_list=input_bus.bus) - setattr(self, attr_name, bus) - self.inputs.append(bus) + full_prefix = f"{self.prefix}_{input.prefix}" if self.inner_component else f"{input.prefix}" + if isinstance(input, Bus): + bus = Bus(prefix=full_prefix, wires_list=input.bus) + setattr(self, attr_name, bus) + self.inputs.append(bus) + + # If the input bus is an output bus, connect it + if input.is_output_bus(): + getattr(self, attr_name).connect_bus(connecting_bus=input) + else: + wire = Wire(name=input.name, prefix=full_prefix) + setattr(self, attr_name, wire) + self.inputs.append(wire) - # If the input bus is an output bus, connect it - if input_bus.is_output_bus(): - getattr(self, attr_name).connect_bus(connecting_bus=input_bus) else: self.inputs = inputs diff --git a/tests/test_all.py b/tests/test_all.py index cb4f85a..d194669 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -337,7 +337,7 @@ def test_mac(): def test_direct(): class err_circuit(GeneralCircuit): - def __init__(self, prefix: str = "", name: str = "adder", inner_component: bool = True, a: Bus = Bus(), b: Bus = Bus()): + def __init__(self, a: Bus = Bus(), b: Bus = Bus(), prefix: str = "", name: str = "adder", inner_component: bool = False): super().__init__(prefix=prefix, name=name, out_N=(a.N + 1), inner_component=inner_component, inputs=[a, b]) self.N = 1 self.prefix = prefix @@ -389,4 +389,6 @@ if __name__ == "__main__": test_unsigned_add() test_signed_add() test_mac() + test_direct() + test_wire_as_bus() print("Python tests were successful!")