diff --git a/NEWS.md b/NEWS.md index 316b50d0d..ad78e01f9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # dbplyr (development version) +* The class of remote sources now includes all S4 class names, not just + the first (#918). + * `db_explain()` now works for Oracle (@thomashulst, #1353). * Database errors now show the generated SQL, which hopefully will make it diff --git a/R/src_dbi.R b/R/src_dbi.R index 57e6c193f..978be8439 100644 --- a/R/src_dbi.R +++ b/R/src_dbi.R @@ -123,17 +123,20 @@ src_dbi <- function(con, auto_disconnect = FALSE) { disco <- db_disconnector(con, quiet = is_true(auto_disconnect)) # nocov } - subclass <- paste0("src_", class(con)[[1]]) - structure( list( con = con, disco = disco ), - class = c(subclass, "src_dbi", "src_sql", "src") + class = connection_s3_class(con) ) } +connection_s3_class <- function(con) { + subclass <- setdiff(methods::is(con), methods::extends("DBIConnection")) + c(paste0("src_", subclass), "src_dbi", "src_sql", "src") +} + methods::setOldClass(c("src_dbi", "src_sql", "src")) # nocov start diff --git a/tests/testthat/test-src_dbi.R b/tests/testthat/test-src_dbi.R index e7b8c23a3..337808779 100644 --- a/tests/testthat/test-src_dbi.R +++ b/tests/testthat/test-src_dbi.R @@ -4,3 +4,21 @@ test_that("tbl and src classes include connection class", { expect_true(inherits(mf, "tbl_SQLiteConnection")) expect_true(inherits(mf$src, "src_SQLiteConnection")) }) + +test_that("generates S3 class based on S4 class name", { + con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:") + expect_equal( + connection_s3_class(con), + c("src_SQLiteConnection", "src_dbi", "src_sql", "src") + ) + + on.exit(removeClass("Foo2")) + on.exit(removeClass("Foo1")) + + Foo1 <- setClass("Foo1", contains = "DBIConnection") + Foo2 <- setClass("Foo2", contains = "Foo1") + expect_equal( + connection_s3_class(Foo2()), + c("src_Foo2", "src_Foo1", "src_dbi", "src_sql", "src") + ) +})