Skip to content
Merged
33 changes: 30 additions & 3 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,11 @@ void Converter::EmitRustStructOrUnion(clang::RecordDecl *decl) {
}
}

if (decl->isUnion()) {
EmitRustUnion(decl);
return;
}

// Derived traits
if (EmitsReprCForRecords()) {
StrCat("#[repr(C)]");
Expand All @@ -770,8 +775,7 @@ void Converter::EmitRustStructOrUnion(clang::RecordDecl *decl) {
auto access = clang::dyn_cast<clang::CXXRecordDecl>(decl)
? AccessSpecifierAsString(decl->getAccess())
: keyword::kPub;
StrCat(access, decl->isUnion() ? keyword::kUnion : keyword::kStruct,
GetRecordName(decl));
StrCat(access, keyword::kStruct, GetRecordName(decl));
{
PushBrace brace(*this);
for (auto *field : decl->fields()) {
Expand Down Expand Up @@ -817,6 +821,29 @@ void Converter::EmitRustStructOrUnion(clang::RecordDecl *decl) {
AddByteReprTrait(decl);
}

void Converter::EmitRustUnion(clang::RecordDecl *decl) {
StrCat("#[repr(C)]");
auto attrs = GetStructAttributes(decl);
Mapper::SetDerives(ctx_.getCanonicalTagType(decl),
std::vector<std::string>(attrs.begin(), attrs.end()));
StrCat("#[derive(");
for (auto *attr : attrs) {
StrCat(attr, ',');
}
StrCat(")]");

StrCat(keyword::kPub, keyword::kUnion, GetRecordName(decl));
{
PushBrace brace(*this);
for (auto *field : decl->fields()) {
VisitFieldDecl(field);
}
}

AddDefaultTrait(decl);
AddByteReprTrait(decl);
}

bool Converter::VisitCXXRecordDecl(clang::CXXRecordDecl *decl) {
if (clang::isa<clang::ClassTemplateSpecializationDecl>(decl)) {
materializeTemplateSpecialization(decl);
Expand Down Expand Up @@ -3878,7 +3905,7 @@ void Converter::AddOrdTrait(const clang::CXXRecordDecl *decl) {
ConvertOrdAndPartialOrdTraits(decl, methods[0]);
}

void Converter::AddCloneTrait(const clang::CXXRecordDecl *decl) {}
void Converter::AddCloneTrait(const clang::RecordDecl *decl) {}

void Converter::AddDropTrait(const clang::CXXRecordDecl *decl) {}

Expand Down
4 changes: 3 additions & 1 deletion cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual void EmitRustStructOrUnion(clang::RecordDecl *decl);

virtual void EmitRustUnion(clang::RecordDecl *decl);

virtual bool EmitsReprCForRecords() const { return true; }

virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *decl);
Expand Down Expand Up @@ -527,7 +529,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
std::string_view second_return,
std::string_view record_name);

virtual void AddCloneTrait(const clang::CXXRecordDecl *decl);
virtual void AddCloneTrait(const clang::RecordDecl *decl);

virtual void AddDropTrait(const clang::CXXRecordDecl *decl);

Expand Down
85 changes: 81 additions & 4 deletions cpp2rust/converter/models/converter_refcount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,17 +441,28 @@ void ConverterRefCount::ConvertOrdAndPartialOrdTraits(
second_return, GetRecordName(decl));
}

void ConverterRefCount::AddCloneTrait(const clang::CXXRecordDecl *decl) {
if (decl->defaultedCopyConstructorIsDeleted()) {
void ConverterRefCount::AddCloneTrait(const clang::RecordDecl *decl) {
auto record_name = GetRecordName(decl);

if (decl->isUnion()) {
StrCat("impl Clone for", record_name);
PushBrace impl_brace(*this);
StrCat("fn clone(&self) -> Self");
PushBrace fn_brace(*this);
StrCat(record_name,
"{ __bytes: Rc::new(RefCell::new(self.__bytes.borrow().clone())) }");
return;
}

auto record_name = GetRecordName(decl);
auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(decl);
if (!cxx || cxx->defaultedCopyConstructorIsDeleted()) {
return;
}

StrCat(keyword::kImpl, "Clone for", record_name, '{');
StrCat("fn clone(&self) -> Self {");

for (auto ctor : decl->ctors()) {
for (auto ctor : cxx->ctors()) {
if (ctor->isCopyConstructor()) {
PushConversionKind push(*this, ConversionKind::FullRefCount);
ConvertCXXConstructorBody(ctor);
Expand All @@ -469,6 +480,39 @@ void ConverterRefCount::AddDefaultTrait(const clang::RecordDecl *decl) {
}

void ConverterRefCount::AddDefaultTraitForUnion(const clang::RecordDecl *decl) {
auto name = GetRecordName(decl);
StrCat("impl Default for", name);
PushBrace impl_brace(*this);
StrCat("fn default() -> Self");
PushBrace fn_brace(*this);
StrCat(std::format(
"{} {{ __bytes: Rc::new(RefCell::new(Box::from([0u8; {}]))) }}", name,
ctx_.getASTRecordLayout(decl).getSize().getQuantity()));
}

void ConverterRefCount::EmitRustUnion(clang::RecordDecl *decl) {
auto name = GetRecordName(decl);

auto attrs = GetStructAttributes(decl);
Mapper::SetDerives(ctx_.getCanonicalTagType(decl),
std::vector<std::string>(attrs.begin(), attrs.end()));

StrCat(std::format("pub struct {} {{ __bytes: Value<Box<[u8]>> }}", name));

StrCat("impl", name);
{
PushBrace impl_brace(*this);
for (auto *field : decl->fields()) {
StrCat(std::format(
"pub fn {}(&self) -> Ptr<{}> {{ (self.__bytes.as_pointer() "
"as Ptr<u8>).reinterpret_cast() }}",
GetNamedDeclAsString(field), Mapper::Map(field->getType())));
}
}

AddCloneTrait(decl);
AddDefaultTrait(decl);
AddByteReprTrait(decl);
}

void ConverterRefCount::AddDropTrait(const clang::CXXRecordDecl *decl) {
Expand Down Expand Up @@ -1441,6 +1485,28 @@ bool ConverterRefCount::VisitInitListExpr(clang::InitListExpr *expr) {
return false;
}

void ConverterRefCount::ConvertUnionMemberAccessor(clang::MemberExpr *expr) {
std::string str;
{
Buffer buf(*this);
PushExprKind push(*this, isLValue() ? ExprKind::LValue : ExprKind::RValue);
Converter::ConvertMemberExpr(expr);
str = std::move(buf).str();
}
str += "()";

if (isAddrOf()) {
StrCat(str);
computed_expr_type_ = ComputedExprType::Pointer;
return;
}
if (isLValue()) {
pending_deref_.set(str);
return;
}
StrCat(DerefPtrExpr(str, expr->getMemberDecl()->getType()));
}

bool ConverterRefCount::VisitMemberExpr(clang::MemberExpr *expr) {
auto *member = expr->getMemberDecl();
bool known = Mapper::Contains(expr);
Expand All @@ -1460,6 +1526,13 @@ bool ConverterRefCount::VisitMemberExpr(clang::MemberExpr *expr) {
return false;
}

if (auto *parent =
clang::dyn_cast<clang::RecordDecl>(member->getDeclContext());
parent && parent->isUnion() && clang::isa<clang::FieldDecl>(member)) {
ConvertUnionMemberAccessor(expr);
return false;
}

std::string str;
if (known) {
str = GetMappedAsString(expr);
Expand Down Expand Up @@ -1798,6 +1871,10 @@ std::vector<const char *>
ConverterRefCount::GetStructAttributes(const clang::RecordDecl *decl) {
std::vector<const char *> attrs;

if (decl->isUnion()) {
return attrs;
}

if (RecordDerivesDefault(decl)) {
attrs.emplace_back("Default");
}
Expand Down
6 changes: 5 additions & 1 deletion cpp2rust/converter/models/converter_refcount.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class ConverterRefCount final : public Converter {

bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl) override;

void EmitRustUnion(clang::RecordDecl *decl) override;

bool EmitsReprCForRecords() const override { return false; }

void ConvertOrdAndPartialOrdTraits(const clang::CXXRecordDecl *decl,
const clang::FunctionDecl *op) override;

void AddCloneTrait(const clang::CXXRecordDecl *decl) override;
void AddCloneTrait(const clang::RecordDecl *decl) override;

void AddDropTrait(const clang::CXXRecordDecl *decl) override;

Expand Down Expand Up @@ -103,6 +105,8 @@ class ConverterRefCount final : public Converter {

bool VisitMemberExpr(clang::MemberExpr *expr) override;

void ConvertUnionMemberAccessor(clang::MemberExpr *expr);

bool VisitCXXNewExpr(clang::CXXNewExpr *expr) override;

bool VisitCXXDeleteExpr(clang::CXXDeleteExpr *expr) override;
Expand Down
Loading
Loading