Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework enums in rust. #6098

Merged
merged 20 commits into from
Oct 19, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
244 changes: 133 additions & 111 deletions samples/monster_generated.rs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions samples/sample_binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,9 @@ fn main() {

println!("The FlatBuffer was successfully created and accessed!");
}

#[cfg(test)]
#[test]
fn test_main() {
main()
}
243 changes: 147 additions & 96 deletions src/idl_gen_rust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ std::string AddUnwrapIfRequired(std::string s, bool required) {
}
}

bool IsBitFlagsEnum(const EnumDef &enum_def) {
return enum_def.attributes.Lookup("bit_flags");
}
bool IsBitFlagsEnum(const FieldDef &field) {
EnumDef* ed = field.value.type.enum_def;
return ed && IsBitFlagsEnum(*ed);
}

namespace rust {

class RustGenerator : public BaseGenerator {
Expand Down Expand Up @@ -217,7 +225,10 @@ class RustGenerator : public BaseGenerator {
// the future. as a result, we proactively block these out as reserved
// words.
"follow", "push", "size", "alignment", "to_little_endian",
"from_little_endian", nullptr
"from_little_endian", nullptr,

// used by Enum constants
"ENUM_MAX", "ENUM_MIN", "ENUM_VALUES",
};
for (auto kw = keywords; *kw; kw++) keywords_.insert(*kw);
}
Expand All @@ -230,9 +241,17 @@ class RustGenerator : public BaseGenerator {

assert(!cur_name_space_);

bool import_bitflags = false;
for (auto it = parser_.enums_.vec.begin(); it != parser_.enums_.vec.end();
++it) {
if (IsBitFlagsEnum(**it)) {
import_bitflags = true;
break;
}
}
// Generate imports for the global scope in case no namespace is used
// in the schema file.
GenNamespaceImports(0);
GenNamespaceImports(0, import_bitflags);
code_ += "";

// Generate all code in their namespaces, once, because Rust does not
Expand Down Expand Up @@ -512,85 +531,157 @@ class RustGenerator : public BaseGenerator {

std::string GetEnumValUse(const EnumDef &enum_def,
const EnumVal &enum_val) const {
return Name(enum_def) + "::" + Name(enum_val);
const std::string val = IsBitFlagsEnum(enum_def) ?
MakeUpper(MakeSnakeCase(Name(enum_val))) : Name(enum_val);
return Name(enum_def) + "::" + val;
}


void ForAllEnumValues(const EnumDef &enum_def,
std::function<void(const EnumVal&)> cb) {
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
const auto &ev = **it;
code_.SetValue("VARIANT", Name(ev));
code_.SetValue("SSC_VARIANT", MakeUpper(MakeSnakeCase(Name(ev))));
code_.SetValue("VALUE", enum_def.ToString(ev));
cb(ev);
}
}
void ForAllEnumValues(const EnumDef &enum_def, std::function<void()> cb) {
ForAllEnumValues(enum_def, [&](const EnumVal& unused) { cb(); });
}
// Generate an enum declaration,
// an enum string lookup table,
// an enum match function,
// and an enum array of values
void GenEnum(const EnumDef &enum_def) {
code_.SetValue("ENUM_NAME", Name(enum_def));
code_.SetValue("BASE_TYPE", GetEnumTypeForDecl(enum_def.underlying_type));

GenComment(enum_def.doc_comment);
code_ += "#[allow(non_camel_case_types)]";
code_ += "#[repr({{BASE_TYPE}})]";
code_ +=
"#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]";
code_ += "pub enum " + Name(enum_def) + " {";

for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
const auto &ev = **it;

GenComment(ev.doc_comment, " ");
code_.SetValue("KEY", Name(ev));
code_.SetValue("VALUE", enum_def.ToString(ev));
code_ += " {{KEY}} = {{VALUE}},";
}
code_.SetValue("ENUM_NAME_SNAKE", MakeSnakeCase(Name(enum_def)));
code_.SetValue("ENUM_NAME_CAPS", MakeUpper(MakeSnakeCase(Name(enum_def))));
const EnumVal *minv = enum_def.MinValue();
const EnumVal *maxv = enum_def.MaxValue();
FLATBUFFERS_ASSERT(minv && maxv);
code_.SetValue("ENUM_MIN_BASE_VALUE", enum_def.ToString(*minv));
code_.SetValue("ENUM_MAX_BASE_VALUE", enum_def.ToString(*maxv));

code_ += "";
if (IsBitFlagsEnum(enum_def)) {
// Defer to the convenient and canonical bitflags crate.
code_ += "bitflags::bitflags! {";
GenComment(enum_def.doc_comment);
code_ += " pub struct {{ENUM_NAME}}: {{BASE_TYPE}} {";
ForAllEnumValues(enum_def, [&]{
code_ += " const {{SSC_VARIANT}} = {{VALUE}};";
});
code_ += " }";
code_ += "}";
code_ += "";
// Generate Follow and Push so we can serialize and stuff.
code_ += "impl<'a> flatbuffers::Follow<'a> for {{ENUM_NAME}} {";
code_ += " type Inner = Self;";
code_ += " #[inline]";
code_ += " fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {";
code_ += " let bits = flatbuffers::read_scalar_at::<{{BASE_TYPE}}>(buf, loc);";
code_ += " unsafe { Self::from_bits_unchecked(bits) }";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {";
code_ += " type Output = {{ENUM_NAME}};";
code_ += " #[inline]";
code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {";
code_ += " flatbuffers::emplace_scalar::<{{BASE_TYPE}}>"
"(dst, self.bits());";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::EndianScalar for {{ENUM_NAME}} {";
code_ += " #[inline]";
code_ += " fn to_little_endian(self) -> Self {";
code_ += " let bits = {{BASE_TYPE}}::to_le(self.bits());";
code_ += " unsafe { Self::from_bits_unchecked(bits) }";
code_ += " }";
code_ += " #[inline]";
code_ += " fn from_little_endian(self) -> Self {";
code_ += " let bits = {{BASE_TYPE}}::from_le(self.bits());";
code_ += " unsafe { Self::from_bits_unchecked(bits) }";
code_ += " }";
code_ += "}";
code_ += "";
return;
}

GenComment(enum_def.doc_comment);
code_ +=
"#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]";
code_ += "pub struct {{ENUM_NAME}}(pub {{BASE_TYPE}});";
code_ += "#[allow(non_upper_case_globals)]";
code_ += "impl {{ENUM_NAME}} {";
code_ += " pub const ENUM_MIN: {{BASE_TYPE}} = {{ENUM_MIN_BASE_VALUE}};";
code_ += " pub const ENUM_MAX: {{BASE_TYPE}} = {{ENUM_MAX_BASE_VALUE}};";

ForAllEnumValues(enum_def, [&](const EnumVal &ev){
GenComment(ev.doc_comment, " ");
code_ += " pub const {{VARIANT}}: Self = Self({{VALUE}});";
});
code_ += " pub const ENUM_VALUES: &'static [Self] = &[";
ForAllEnumValues(enum_def, [&]{
code_ += " Self::{{VARIANT}},";
});
code_ += " ];";
code_ += " /// Returns the variant's name or \"\" if unknown.";
code_ += " pub fn variant_name(self) -> &'static str {";
code_ += " match self {";
ForAllEnumValues(enum_def, [&]{
code_ += " Self::{{VARIANT}} => \"{{VARIANT}}\",";
});
code_ += " _ => \"\",";
code_ += " }";
code_ += " }";
code_ += "}";
code_ += "";

code_.SetValue("ENUM_NAME", Name(enum_def));
code_.SetValue("ENUM_NAME_SNAKE", MakeSnakeCase(Name(enum_def)));
code_.SetValue("ENUM_NAME_CAPS", MakeUpper(MakeSnakeCase(Name(enum_def))));
code_.SetValue("ENUM_MIN_BASE_VALUE", enum_def.ToString(*minv));
code_.SetValue("ENUM_MAX_BASE_VALUE", enum_def.ToString(*maxv));
// Generate Debug. Unknown variants are printed like "<UNKNOWN 42>".
code_ += "impl std::fmt::Debug for {{ENUM_NAME}} {";
code_ += " fn fmt(&self, f: &mut std::fmt::Formatter) ->"
" std::fmt::Result {";
code_ += " let name = self.variant_name();";
code_ += " if name.is_empty() {";
code_ += " f.write_fmt(format_args!(\"<UNKNOWN {:?}>\", self.0))";
code_ += " } else {";
code_ += " f.write_str(name)";
code_ += " }";
code_ += " }";
code_ += "}";

// Generate enum constants, and impls for Follow, EndianScalar, and Push.
code_ += "pub const ENUM_MIN_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}} = \\";
code_ += "{{ENUM_MIN_BASE_VALUE}};";
code_ += "pub const ENUM_MAX_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}} = \\";
code_ += "{{ENUM_MAX_BASE_VALUE}};";
code_ += "";
// Generate Follow and Push so we can serialize and stuff.
code_ += "impl<'a> flatbuffers::Follow<'a> for {{ENUM_NAME}} {";
code_ += " type Inner = Self;";
code_ += " #[inline]";
code_ += " fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {";
code_ += " flatbuffers::read_scalar_at::<Self>(buf, loc)";
code_ += " Self(flatbuffers::read_scalar_at::<{{BASE_TYPE}}>(buf, loc))";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {";
code_ += " type Output = {{ENUM_NAME}};";
code_ += " #[inline]";
code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {";
code_ += " flatbuffers::emplace_scalar::<{{BASE_TYPE}}>"
"(dst, self.0);";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::EndianScalar for {{ENUM_NAME}} {";
code_ += " #[inline]";
code_ += " fn to_little_endian(self) -> Self {";
code_ += " let n = {{BASE_TYPE}}::to_le(self as {{BASE_TYPE}});";
code_ += " let p = &n as *const {{BASE_TYPE}} as *const {{ENUM_NAME}};";
code_ += " unsafe { *p }";
code_ += " Self({{BASE_TYPE}}::to_le(self.0))";
code_ += " }";
code_ += " #[inline]";
code_ += " fn from_little_endian(self) -> Self {";
code_ += " let n = {{BASE_TYPE}}::from_le(self as {{BASE_TYPE}});";
code_ += " let p = &n as *const {{BASE_TYPE}} as *const {{ENUM_NAME}};";
code_ += " unsafe { *p }";
code_ += " Self({{BASE_TYPE}}::from_le(self.0))";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {";
code_ += " type Output = {{ENUM_NAME}};";
code_ += " #[inline]";
code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {";
code_ +=
" flatbuffers::emplace_scalar::<{{ENUM_NAME}}>"
"(dst, *self);";
code_ += " }";
code_ += "}";
code_ += "";

// Generate an array of all enumeration values.
auto num_fields = NumToString(enum_def.size());
Expand All @@ -606,49 +697,6 @@ class RustGenerator : public BaseGenerator {
code_ += "];";
code_ += "";

// Generate a string table for enum values.
// Problem is, if values are very sparse that could generate really big
// tables. Ideally in that case we generate a map lookup instead, but for
// the moment we simply don't output a table at all.
auto range = enum_def.Distance();
// Average distance between values above which we consider a table
// "too sparse". Change at will.
static const uint64_t kMaxSparseness = 5;
if (range / static_cast<uint64_t>(enum_def.size()) < kMaxSparseness) {
code_ += "#[allow(non_camel_case_types)]";
code_ += "pub const ENUM_NAMES_{{ENUM_NAME_CAPS}}: [&str; " +
NumToString(range + 1) + "] = [";

auto val = enum_def.Vals().front();
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end();
++it) {
auto ev = *it;
for (auto k = enum_def.Distance(val, ev); k > 1; --k) {
code_ += " \"\",";
}
val = ev;
auto suffix = *it != enum_def.Vals().back() ? "," : "";
code_ += " \"" + Name(*ev) + "\"" + suffix;
}
code_ += "];";
code_ += "";

code_ +=
"pub fn enum_name_{{ENUM_NAME_SNAKE}}(e: {{ENUM_NAME}}) -> "
"&'static str {";

code_ += " let index = e as {{BASE_TYPE}}\\";
if (enum_def.MinValue()->IsNonZero()) {
auto vals = GetEnumValUse(enum_def, *enum_def.MinValue());
code_ += " - " + vals + " as {{BASE_TYPE}}\\";
}
code_ += ";";

code_ += " ENUM_NAMES_{{ENUM_NAME_CAPS}}[index as usize]";
code_ += "}";
code_ += "";
}

if (enum_def.is_union) {
// Generate tyoesafe offset(s) for unions
code_.SetValue("NAME", Name(enum_def));
Expand Down Expand Up @@ -1029,9 +1077,8 @@ class RustGenerator : public BaseGenerator {
}
case ftUnionKey:
case ftEnumKey: {
const auto underlying_typname = GetTypeBasic(type); //<- never used
const auto typname = WrapInNameSpace(*type.enum_def);
const auto default_value = GetDefaultScalarValue(field);
const std::string typname = WrapInNameSpace(*type.enum_def);
const std::string default_value = GetDefaultScalarValue(field);
if (field.optional) {
return "self._tab.get::<" + typname + ">(" + offset_name + ", None)";
} else {
Expand Down Expand Up @@ -1770,7 +1817,10 @@ class RustGenerator : public BaseGenerator {
code_ += "";
}

void GenNamespaceImports(const int white_spaces) {
void GenNamespaceImports(const int white_spaces, bool bitflags=false) {
if (white_spaces == 0) {
code_ += "#![allow(unused_imports, dead_code)]";
}
std::string indent = std::string(white_spaces, ' ');
code_ += "";
if (!parser_.opts.generate_all) {
Expand All @@ -1788,6 +1838,7 @@ class RustGenerator : public BaseGenerator {
code_ += indent + "use std::cmp::Ordering;";
code_ += "";
code_ += indent + "extern crate flatbuffers;";
if (bitflags) code_ += indent + "extern crate bitflags;";
code_ += indent + "use self::flatbuffers::EndianScalar;";
}

Expand Down
1 change: 1 addition & 0 deletions tests/include_test/include_test1_generated.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// automatically generated by the FlatBuffers compiler, do not modify


#![allow(unused_imports, dead_code)]

use crate::include_test2_generated::*;
use std::mem;
Expand Down