diff --git a/jvm/src/test/scala/scala/xml/XMLTest.scala b/jvm/src/test/scala/scala/xml/XMLTest.scala
index 43dd182e2..5008bd6e6 100644
--- a/jvm/src/test/scala/scala/xml/XMLTest.scala
+++ b/jvm/src/test/scala/scala/xml/XMLTest.scala
@@ -550,6 +550,24 @@ class XMLTestJVM {
     XML.loadString(broken)
   }
 
+  @UnitTest
+  def issueSI9047AttributeFromSingleChildElementWorks: Unit = {
+    val x = <x><a b='1'/></x>
+
+    val b = x \ "a" \ "@b"
+
+    assertEquals(List("1"), b map (_.text))
+  }
+
+  @UnitTest
+  def issueSI9047AttributeMultipleChildElementsWorks: Unit = {
+    val x = <x><a b='1'/><a b='2'/></x>
+
+    val b = x \ "a" \ "@b"
+
+    assertEquals(List("1", "2"), b map (_.text))
+  }
+
   @UnitTest
   def nodeSeqNs: Unit = {
     val x = {
diff --git a/shared/src/main/scala/scala/xml/NodeSeq.scala b/shared/src/main/scala/scala/xml/NodeSeq.scala
index c498279e9..df43f2553 100644
--- a/shared/src/main/scala/scala/xml/NodeSeq.scala
+++ b/shared/src/main/scala/scala/xml/NodeSeq.scala
@@ -95,8 +95,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
   def \(that: String): NodeSeq = {
     def fail = throw new IllegalArgumentException(that)
     def atResult = {
-      lazy val y = this(0)
-      val attr =
+      this flatMap (y => (
         if (that.length == 1) fail
         else if (that(1) == '{') {
           val i = that indexOf '}'
@@ -105,10 +104,9 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
           if (uri == "" || key == "") fail
           else y.attribute(uri, key)
         } else y.attribute(that drop 1)
-
-      attr match {
-        case Some(x) => Group(x)
-        case _       => NodeSeq.Empty
+      ).getOrElse(Nil)) match {
+        case NodeSeq.Empty => NodeSeq.Empty
+        case x => Group(x)
       }
     }
 
@@ -118,7 +116,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
     that match {
       case ""                                        => fail
       case "_"                                       => makeSeq(!_.isAtom)
-      case _ if (that(0) == '@' && this.length == 1) => atResult
+      case _ if that(0) == '@'                       => atResult
       case _                                         => makeSeq(_.label == that)
     }
   }
diff --git a/shared/src/test/scala/scala/xml/AttributeTest.scala b/shared/src/test/scala/scala/xml/AttributeTest.scala
index 8943a0a00..b7a0d3115 100644
--- a/shared/src/test/scala/scala/xml/AttributeTest.scala
+++ b/shared/src/test/scala/scala/xml/AttributeTest.scala
@@ -147,9 +147,9 @@ class AttributeTest {
     val b = xml \ "b"
     assertEquals(2, b.length)
     assertEquals(NodeSeq.fromSeq(Seq(<b bar="1"/>, <b bar="2"/>)), b)
-    val barFail = b \ "@bar"
+    val barAttributesDirect = b \ "@bar"
     val barList =  b.map(_ \ "@bar")
-    assertEquals(NodeSeq.Empty, barFail)
+    assertEquals(Group(Seq(Text("1"), Text("2"))), barAttributesDirect)
     assertEquals(List(Group(Seq(Text("1"))), Group(Seq(Text("2")))), barList)
   }